10 #include "llvm/IR/DebugInfoMetadata.h"
18 struct LoopAnnotationConversion {
19 LoopAnnotationConversion(LoopAnnotationAttr attr,
Operation *op,
21 llvm::LLVMContext &ctx)
23 loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
27 llvm::MDNode *convert();
30 void addUnitNode(StringRef name);
31 void addUnitNode(StringRef name,
BoolAttr attr);
32 void addI32NodeWithVal(StringRef name, uint32_t val);
33 void convertBoolNode(StringRef name,
BoolAttr attr,
bool negated =
false);
34 void convertI32Node(StringRef name, IntegerAttr attr);
35 void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
39 void convertLoopOptions(LoopVectorizeAttr
options);
40 void convertLoopOptions(LoopInterleaveAttr
options);
41 void convertLoopOptions(LoopUnrollAttr
options);
42 void convertLoopOptions(LoopUnrollAndJamAttr
options);
43 void convertLoopOptions(LoopLICMAttr
options);
44 void convertLoopOptions(LoopDistributeAttr
options);
45 void convertLoopOptions(LoopPipelineAttr
options);
46 void convertLoopOptions(LoopPeeledAttr
options);
47 void convertLoopOptions(LoopUnswitchAttr
options);
49 LoopAnnotationAttr attr;
52 llvm::LLVMContext &ctx;
57 void LoopAnnotationConversion::addUnitNode(StringRef name) {
58 metadataNodes.push_back(
62 void LoopAnnotationConversion::addUnitNode(StringRef name,
BoolAttr attr) {
67 void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
70 metadataNodes.push_back(
75 void LoopAnnotationConversion::convertBoolNode(StringRef name,
BoolAttr attr,
79 bool val = negated ^ attr.
getValue();
80 llvm::Constant *cstValue = llvm::ConstantInt::getBool(ctx, val);
81 metadataNodes.push_back(
86 void LoopAnnotationConversion::convertI32Node(StringRef name,
90 addI32NodeWithVal(name, attr.getInt());
93 void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94 LoopAnnotationAttr attr) {
99 loopAnnotationTranslation.translateLoopAnnotation(attr, op);
101 metadataNodes.push_back(
105 void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr
options) {
106 convertBoolNode(
"llvm.loop.vectorize.enable",
options.getDisable(),
true);
107 convertBoolNode(
"llvm.loop.vectorize.predicate.enable",
109 convertBoolNode(
"llvm.loop.vectorize.scalable.enable",
111 convertI32Node(
"llvm.loop.vectorize.width",
options.getWidth());
112 convertFollowupNode(
"llvm.loop.vectorize.followup_vectorized",
113 options.getFollowupVectorized());
114 convertFollowupNode(
"llvm.loop.vectorize.followup_epilogue",
115 options.getFollowupEpilogue());
116 convertFollowupNode(
"llvm.loop.vectorize.followup_all",
120 void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr
options) {
121 convertI32Node(
"llvm.loop.interleave.count",
options.getCount());
124 void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr
options) {
125 if (
auto disable =
options.getDisable())
126 addUnitNode(disable.getValue() ?
"llvm.loop.unroll.disable"
127 :
"llvm.loop.unroll.enable");
128 convertI32Node(
"llvm.loop.unroll.count",
options.getCount());
129 convertBoolNode(
"llvm.loop.unroll.runtime.disable",
131 addUnitNode(
"llvm.loop.unroll.full",
options.getFull());
132 convertFollowupNode(
"llvm.loop.unroll.followup_unrolled",
133 options.getFollowupUnrolled());
134 convertFollowupNode(
"llvm.loop.unroll.followup_remainder",
135 options.getFollowupRemainder());
136 convertFollowupNode(
"llvm.loop.unroll.followup_all",
140 void LoopAnnotationConversion::convertLoopOptions(
141 LoopUnrollAndJamAttr
options) {
142 if (
auto disable =
options.getDisable())
143 addUnitNode(disable.getValue() ?
"llvm.loop.unroll_and_jam.disable"
144 :
"llvm.loop.unroll_and_jam.enable");
145 convertI32Node(
"llvm.loop.unroll_and_jam.count",
options.getCount());
146 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_outer",
148 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_inner",
150 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_outer",
151 options.getFollowupRemainderOuter());
152 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_inner",
153 options.getFollowupRemainderInner());
154 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_all",
158 void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr
options) {
159 addUnitNode(
"llvm.licm.disable",
options.getDisable());
160 addUnitNode(
"llvm.loop.licm_versioning.disable",
161 options.getVersioningDisable());
164 void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr
options) {
165 convertBoolNode(
"llvm.loop.distribute.enable",
options.getDisable(),
true);
166 convertFollowupNode(
"llvm.loop.distribute.followup_coincident",
167 options.getFollowupCoincident());
168 convertFollowupNode(
"llvm.loop.distribute.followup_sequential",
169 options.getFollowupSequential());
170 convertFollowupNode(
"llvm.loop.distribute.followup_fallback",
171 options.getFollowupFallback());
172 convertFollowupNode(
"llvm.loop.distribute.followup_all",
176 void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr
options) {
177 convertBoolNode(
"llvm.loop.pipeline.disable",
options.getDisable());
178 convertI32Node(
"llvm.loop.pipeline.initiationinterval",
179 options.getInitiationinterval());
182 void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr
options) {
183 convertI32Node(
"llvm.loop.peeled.count",
options.getCount());
186 void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr
options) {
187 addUnitNode(
"llvm.loop.unswitch.partial.disable",
191 void LoopAnnotationConversion::convertLocation(
FusedLoc location) {
192 auto localScopeAttr =
193 dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
196 auto *localScope = dyn_cast<llvm::DILocalScope>(
197 loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
201 llvm::Metadata *loc =
202 loopAnnotationTranslation.moduleTranslation.translateLoc(location,
204 metadataNodes.push_back(loc);
207 llvm::MDNode *LoopAnnotationConversion::convert() {
209 auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
210 metadataNodes.push_back(dummy.get());
212 if (
FusedLoc startLoc = attr.getStartLoc())
213 convertLocation(startLoc);
215 if (
FusedLoc endLoc = attr.getEndLoc())
216 convertLocation(endLoc);
218 addUnitNode(
"llvm.loop.disable_nonforced", attr.getDisableNonforced());
219 addUnitNode(
"llvm.loop.mustprogress", attr.getMustProgress());
221 if (
BoolAttr isVectorized = attr.getIsVectorized())
222 addI32NodeWithVal(
"llvm.loop.isvectorized", isVectorized.getValue());
224 if (
auto options = attr.getVectorize())
226 if (
auto options = attr.getInterleave())
228 if (
auto options = attr.getUnroll())
230 if (
auto options = attr.getUnrollAndJam())
232 if (
auto options = attr.getLicm())
234 if (
auto options = attr.getDistribute())
236 if (
auto options = attr.getPipeline())
238 if (
auto options = attr.getPeeled())
240 if (
auto options = attr.getUnswitch())
244 if (!parallelAccessGroups.empty()) {
246 parallelAccess.push_back(
248 for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249 parallelAccess.push_back(
250 loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
256 loopMD->replaceOperandWith(0, loopMD);
267 llvm::MDNode *loopMD = lookupLoopMetadata(attr);
272 LoopAnnotationConversion(attr, op, *
this, this->llvmModule.getContext())
276 mapLoopMetadata(attr, loopMD);
282 auto [result, inserted] =
283 accessGroupMetadataMapping.insert({accessGroupAttr,
nullptr});
285 result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
286 return result->second;
291 ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292 if (!accessGroups || accessGroups.empty())
296 for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
298 if (groupMDs.size() == 1)
299 return llvm::cast<llvm::MDNode>(groupMDs.front());
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
llvm::MDNode * translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op)
llvm::MDNode * getAccessGroups(AccessGroupOpInterface op)
Returns the LLVM metadata corresponding to the access group attribute referenced by the AccessGroupOp...
llvm::MDNode * getAccessGroup(AccessGroupAttr accessGroupAttr)
Returns the LLVM metadata corresponding to an mlir LLVM dialect access group attribute.
Operation is the basic unit of execution within MLIR.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...