MLIR 22.0.0git
LoopAnnotationImporter.cpp
Go to the documentation of this file.
1//===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/IR/Constants.h"
11
12using namespace mlir;
13using namespace mlir::LLVM;
14using namespace mlir::LLVM::detail;
15
16namespace {
17/// Helper class that keeps the state of one metadata to attribute conversion.
18struct LoopMetadataConversion {
19 LoopMetadataConversion(const llvm::MDNode *node, Location loc,
20 LoopAnnotationImporter &loopAnnotationImporter)
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
22 ctx(loc->getContext()){};
23 /// Converts this structs loop metadata node into a LoopAnnotationAttr.
24 LoopAnnotationAttr convert();
25
26 /// Initializes the shared state for the conversion member functions.
27 LogicalResult initConversionState();
28
29 /// Helper function to get and erase a property.
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
31
32 /// Helper functions to lookup and convert MDNodes into a specifc attribute
33 /// kind. These functions return null-attributes if there is no node with the
34 /// specified name, or failure, if the node is ill-formatted.
35 FailureOr<BoolAttr> lookupUnitNode(StringRef name);
36 FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false);
37 FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
38 FailureOr<IntegerAttr> lookupIntNode(StringRef name);
39 FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
40 FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
41 FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
42 FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
43 StringRef disableName,
44 bool negated = false);
45
46 /// Conversion functions for sub-attributes.
47 FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
48 FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
49 FailureOr<LoopUnrollAttr> convertUnrollAttr();
50 FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
51 FailureOr<LoopLICMAttr> convertLICMAttr();
52 FailureOr<LoopDistributeAttr> convertDistributeAttr();
53 FailureOr<LoopPipelineAttr> convertPipelineAttr();
54 FailureOr<LoopPeeledAttr> convertPeeledAttr();
55 FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
56 FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
57 FusedLoc convertStartLoc();
58 FailureOr<FusedLoc> convertEndLoc();
59
60 llvm::SmallVector<llvm::DILocation *, 2> locations;
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
62 const llvm::MDNode *node;
63 Location loc;
64 LoopAnnotationImporter &loopAnnotationImporter;
65 MLIRContext *ctx;
66};
67} // namespace
68
69LogicalResult LoopMetadataConversion::initConversionState() {
70 // Check if it's a valid node.
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(node->getOperand(0)) != node)
73 return emitWarning(loc) << "invalid loop node";
74
75 for (const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) {
76 if (auto *diLoc = dyn_cast<llvm::DILocation>(operand)) {
77 locations.push_back(diLoc);
78 continue;
79 }
80
81 auto *property = dyn_cast<llvm::MDNode>(operand);
82 if (!property)
83 return emitWarning(loc) << "expected all loop properties to be either "
84 "debug locations or metadata nodes";
85
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) << "cannot import empty loop property";
88
89 auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0));
90 if (!nameNode)
91 return emitWarning(loc) << "cannot import loop property without a name";
92 StringRef name = nameNode->getString();
93
94 bool succ = propertyMap.try_emplace(name, property).second;
95 if (!succ)
96 return emitWarning(loc)
97 << "cannot import loop properties with duplicated names " << name;
98 }
99
100 return success();
101}
102
103const llvm::MDNode *
104LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(name);
106 if (it == propertyMap.end())
107 return nullptr;
108 const llvm::MDNode *property = it->getValue();
109 propertyMap.erase(it);
110 return property;
111}
112
113FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
114 const llvm::MDNode *property = lookupAndEraseProperty(name);
115 if (!property)
116 return BoolAttr(nullptr);
117
118 if (property->getNumOperands() != 1)
119 return emitWarning(loc)
120 << "expected metadata node " << name << " to hold no value";
121
122 return BoolAttr::get(ctx, true);
123}
124
125FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName, StringRef disableName, bool negated) {
127 auto enable = lookupUnitNode(enableName);
128 auto disable = lookupUnitNode(disableName);
129 if (failed(enable) || failed(disable))
130 return failure();
131
132 if (*enable && *disable)
133 return emitWarning(loc)
134 << "expected metadata nodes " << enableName << " and " << disableName
135 << " to be mutually exclusive.";
136
137 if (*enable)
138 return BoolAttr::get(ctx, !negated);
139
140 if (*disable)
141 return BoolAttr::get(ctx, negated);
142 return BoolAttr(nullptr);
143}
144
145FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
146 bool negated) {
147 const llvm::MDNode *property = lookupAndEraseProperty(name);
148 if (!property)
149 return BoolAttr(nullptr);
150
151 auto emitNodeWarning = [&]() {
152 return emitWarning(loc)
153 << "expected metadata node " << name << " to hold a boolean value";
154 };
155
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
162
163 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
164}
165
166FailureOr<BoolAttr>
167LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *property = lookupAndEraseProperty(name);
169 if (!property)
170 return BoolAttr(nullptr);
171
172 auto emitNodeWarning = [&]() {
173 return emitWarning(loc)
174 << "expected metadata node " << name << " to hold an integer value";
175 };
176
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
183
184 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
185}
186
187FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
188 const llvm::MDNode *property = lookupAndEraseProperty(name);
189 if (!property)
190 return IntegerAttr(nullptr);
191
192 auto emitNodeWarning = [&]() {
193 return emitWarning(loc)
194 << "expected metadata node " << name << " to hold an i32 value";
195 };
196
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
199
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
204
205 return IntegerAttr::get(IntegerType::get(ctx, 32),
206 val->getValue().getLimitedValue());
207}
208
209FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
210 const llvm::MDNode *property = lookupAndEraseProperty(name);
211 if (!property)
212 return nullptr;
213
214 auto emitNodeWarning = [&]() {
215 return emitWarning(loc)
216 << "expected metadata node " << name << " to hold an MDNode";
217 };
218
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
221
222 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1));
223 if (!node)
224 return emitNodeWarning();
225
226 return node;
227}
228
229FailureOr<SmallVector<llvm::MDNode *>>
230LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *property = lookupAndEraseProperty(name);
232 SmallVector<llvm::MDNode *> res;
233 if (!property)
234 return res;
235
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) << "expected metadata node " << name
238 << " to hold one or multiple MDNodes";
239 };
240
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
243
244 for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
245 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i));
246 if (!node)
247 return emitNodeWarning();
248 res.push_back(node);
249 }
250
251 return res;
252}
253
254FailureOr<LoopAnnotationAttr>
255LoopMetadataConversion::lookupFollowupNode(StringRef name) {
256 auto node = lookupMDNode(name);
257 if (failed(node))
258 return failure();
259 if (*node == nullptr)
260 return LoopAnnotationAttr(nullptr);
261
262 return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
263}
264
265static bool isEmptyOrNull(const Attribute attr) { return !attr; }
266
267template <typename T>
268static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) {
269 return vec.empty();
270}
271
272/// Helper function that only creates and attribute of type T if all argument
273/// conversion were successfull and at least one of them holds a non-null value.
274template <typename T, typename... P>
275static T createIfNonNull(MLIRContext *ctx, const P &...args) {
276 bool anyFailed = (failed(args) || ...);
277 if (anyFailed)
278 return {};
279
280 bool allEmpty = (isEmptyOrNull(*args) && ...);
281 if (allEmpty)
282 return {};
283
284 return T::get(ctx, *args...);
285}
286
287FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
288 FailureOr<BoolAttr> enable =
289 lookupBoolNode("llvm.loop.vectorize.enable", true);
290 FailureOr<BoolAttr> predicateEnable =
291 lookupBoolNode("llvm.loop.vectorize.predicate.enable");
292 FailureOr<BoolAttr> scalableEnable =
293 lookupBoolNode("llvm.loop.vectorize.scalable.enable");
294 FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width");
295 FailureOr<LoopAnnotationAttr> followupVec =
296 lookupFollowupNode("llvm.loop.vectorize.followup_vectorized");
297 FailureOr<LoopAnnotationAttr> followupEpi =
298 lookupFollowupNode("llvm.loop.vectorize.followup_epilogue");
299 FailureOr<LoopAnnotationAttr> followupAll =
300 lookupFollowupNode("llvm.loop.vectorize.followup_all");
301
302 return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable,
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
305}
306
307FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count");
309 return createIfNonNull<LoopInterleaveAttr>(ctx, count);
310}
311
312FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
313 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
314 "llvm.loop.unroll.enable", "llvm.loop.unroll.disable", /*negated=*/true);
315 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count");
316 FailureOr<BoolAttr> runtimeDisable =
317 lookupUnitNode("llvm.loop.unroll.runtime.disable");
318 FailureOr<BoolAttr> full = lookupUnitNode("llvm.loop.unroll.full");
319 FailureOr<LoopAnnotationAttr> followupUnrolled =
320 lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
321 FailureOr<LoopAnnotationAttr> followupRemainder =
322 lookupFollowupNode("llvm.loop.unroll.followup_remainder");
323 FailureOr<LoopAnnotationAttr> followupAll =
324 lookupFollowupNode("llvm.loop.unroll.followup_all");
325
326 return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
327 full, followupUnrolled,
328 followupRemainder, followupAll);
329}
330
331FailureOr<LoopUnrollAndJamAttr>
332LoopMetadataConversion::convertUnrollAndJamAttr() {
333 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
334 "llvm.loop.unroll_and_jam.enable", "llvm.loop.unroll_and_jam.disable",
335 /*negated=*/true);
336 FailureOr<IntegerAttr> count =
337 lookupIntNode("llvm.loop.unroll_and_jam.count");
338 FailureOr<LoopAnnotationAttr> followupOuter =
339 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr<LoopAnnotationAttr> followupInner =
341 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr<LoopAnnotationAttr> followupRemainderOuter =
343 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr<LoopAnnotationAttr> followupRemainderInner =
345 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr<LoopAnnotationAttr> followupAll =
347 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all");
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
351}
352
353FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
354 FailureOr<BoolAttr> disable = lookupUnitNode("llvm.licm.disable");
355 FailureOr<BoolAttr> versioningDisable =
356 lookupUnitNode("llvm.loop.licm_versioning.disable");
357 return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable);
358}
359
360FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
361 FailureOr<BoolAttr> disable =
362 lookupBoolNode("llvm.loop.distribute.enable", true);
363 FailureOr<LoopAnnotationAttr> followupCoincident =
364 lookupFollowupNode("llvm.loop.distribute.followup_coincident");
365 FailureOr<LoopAnnotationAttr> followupSequential =
366 lookupFollowupNode("llvm.loop.distribute.followup_sequential");
367 FailureOr<LoopAnnotationAttr> followupFallback =
368 lookupFollowupNode("llvm.loop.distribute.followup_fallback");
369 FailureOr<LoopAnnotationAttr> followupAll =
370 lookupFollowupNode("llvm.loop.distribute.followup_all");
371 return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
372 followupSequential,
373 followupFallback, followupAll);
374}
375
376FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
377 FailureOr<BoolAttr> disable = lookupBoolNode("llvm.loop.pipeline.disable");
378 FailureOr<IntegerAttr> initiationinterval =
379 lookupIntNode("llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval);
381}
382
383FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count");
385 return createIfNonNull<LoopPeeledAttr>(ctx, count);
386}
387
388FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr<BoolAttr> partialDisable =
390 lookupUnitNode("llvm.loop.unswitch.partial.disable");
391 return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable);
392}
393
394FailureOr<SmallVector<AccessGroupAttr>>
395LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr<SmallVector<llvm::MDNode *>> nodes =
397 lookupMDNodes("llvm.loop.parallel_accesses");
398 if (failed(nodes))
399 return failure();
400 SmallVector<AccessGroupAttr> refs;
401 for (llvm::MDNode *node : *nodes) {
402 FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
403 loopAnnotationImporter.lookupAccessGroupAttrs(node);
404 if (failed(accessGroups)) {
405 emitWarning(loc) << "could not lookup access group";
406 continue;
407 }
408 llvm::append_range(refs, *accessGroups);
409 }
410 return refs;
411}
412
413FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
415 return {};
416 return dyn_cast<FusedLoc>(
417 loopAnnotationImporter.moduleImport.translateLoc(locations[0]));
418}
419
420FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
421 if (locations.size() < 2)
422 return FusedLoc();
423 if (locations.size() > 2)
424 return emitError(loc)
425 << "expected loop metadata to have at most two DILocations";
426 return dyn_cast<FusedLoc>(
427 loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
428}
429
430LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (failed(initConversionState()))
432 return {};
433
434 FailureOr<BoolAttr> disableNonForced =
435 lookupUnitNode("llvm.loop.disable_nonforced");
436 FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
437 FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
438 FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
439 FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
440 FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
441 FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
442 FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
443 FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
444 FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
445 FailureOr<BoolAttr> mustProgress = lookupUnitNode("llvm.loop.mustprogress");
446 FailureOr<BoolAttr> isVectorized =
447 lookupIntNodeAsBoolAttr("llvm.loop.isvectorized");
448 FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
449 convertParallelAccesses();
450
451 // Drop the metadata if there are parts that cannot be imported.
452 if (!propertyMap.empty()) {
453 for (auto name : propertyMap.keys())
454 emitWarning(loc) << "unknown loop annotation " << name;
455 return {};
456 }
457
458 FailureOr<FusedLoc> startLoc = convertStartLoc();
459 FailureOr<FusedLoc> endLoc = convertEndLoc();
460
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
465 parallelAccesses);
466}
467
468LoopAnnotationAttr
470 Location loc) {
471 if (!node)
472 return {};
473
474 // Note: This check is necessary to distinguish between failed translations
475 // and not yet attempted translations.
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
479
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();
481
482 mapLoopMetadata(node, attr);
483 return attr;
484}
485
486LogicalResult
488 Location loc) {
490 if (!node->getNumOperands())
491 accessGroups.push_back(node);
492 for (const llvm::MDOperand &operand : node->operands()) {
493 auto *childNode = dyn_cast<llvm::MDNode>(operand);
494 if (!childNode)
495 return failure();
496 accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
497 }
498
499 // Convert all entries of the access group list to access group operations.
500 for (const llvm::MDNode *accessGroup : accessGroups) {
501 if (accessGroupMapping.count(accessGroup))
502 continue;
503 // Verify the access group node is distinct and empty.
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
505 return emitWarning(loc)
506 << "expected an access group node to be empty and distinct";
507
508 // Add a mapping from the access group node to the newly created attribute.
509 accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>();
510 }
511 return success();
512}
513
514FailureOr<SmallVector<AccessGroupAttr>>
515LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
516 // An access group node is either a single access group or an access group
517 // list.
518 SmallVector<AccessGroupAttr> accessGroups;
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (const llvm::MDOperand &operand : node->operands()) {
522 auto *node = cast<llvm::MDNode>(operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
524 }
525 // Exit if one of the access group node lookups failed.
526 if (llvm::is_contained(accessGroups, nullptr))
527 return failure();
528 return accessGroups;
529}
return success()
static T createIfNonNull(MLIRContext *ctx, const P &...args)
Helper function that only creates an attribute of type T if all argument conversions were successful...
static bool isEmptyOrNull(const Attribute attr)
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
Location translateLoc(llvm::DILocation *loc)
Translates the debug location.
LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, Location loc)
LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc)
Converts all LLVM access groups starting from node to MLIR access group attributes.
ModuleImport & moduleImport
The ModuleImport owning this instance.
FailureOr< SmallVector< AccessGroupAttr > > lookupAccessGroupAttrs(const llvm::MDNode *node) const
Returns the access group attributes that map to the access group nodes starting from the access group ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.