10#include "llvm/IR/DebugInfoMetadata.h"
18struct LoopAnnotationConversion {
19 LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
20 LoopAnnotationTranslation &loopAnnotationTranslation,
21 llvm::LLVMContext &ctx)
23 loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
27 llvm::MDNode *convert();
30 void addUnitNode(StringRef name);
31 void addUnitNode(StringRef name, BoolAttr attr);
32 void addI32NodeWithVal(StringRef name, uint32_t val);
33 void convertBoolNode(StringRef name, BoolAttr attr,
bool negated =
false);
34 void convertI32Node(StringRef name, IntegerAttr attr);
35 void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
36 void convertLocation(FusedLoc attr);
39 void convertLoopOptions(LoopVectorizeAttr
options);
40 void convertLoopOptions(LoopInterleaveAttr
options);
41 void convertLoopOptions(LoopUnrollAttr
options);
42 void convertLoopOptions(LoopUnrollAndJamAttr
options);
43 void convertLoopOptions(LoopLICMAttr
options);
44 void convertLoopOptions(LoopDistributeAttr
options);
45 void convertLoopOptions(LoopPipelineAttr
options);
46 void convertLoopOptions(LoopPeeledAttr
options);
47 void convertLoopOptions(LoopUnswitchAttr
options);
49 LoopAnnotationAttr attr;
51 LoopAnnotationTranslation &loopAnnotationTranslation;
52 llvm::LLVMContext &ctx;
53 llvm::SmallVector<llvm::Metadata *> metadataNodes;
57void LoopAnnotationConversion::addUnitNode(StringRef name) {
58 metadataNodes.push_back(
59 llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name)}));
62void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
67void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
68 llvm::Constant *cstValue = llvm::ConstantInt::get(
69 llvm::IntegerType::get(ctx, 32), val,
false);
70 metadataNodes.push_back(
71 llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
72 llvm::ConstantAsMetadata::get(cstValue)}));
75void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
79 bool val = negated ^ attr.
getValue();
80 llvm::Constant *cstValue = llvm::ConstantInt::getBool(ctx, val);
81 metadataNodes.push_back(
82 llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
83 llvm::ConstantAsMetadata::get(cstValue)}));
86void LoopAnnotationConversion::convertI32Node(StringRef name,
90 addI32NodeWithVal(name, attr.getInt());
93void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94 LoopAnnotationAttr attr) {
101 metadataNodes.push_back(
102 llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
105void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr
options) {
106 convertBoolNode(
"llvm.loop.vectorize.enable",
options.getDisable(),
true);
107 convertBoolNode(
"llvm.loop.vectorize.predicate.enable",
109 convertBoolNode(
"llvm.loop.vectorize.scalable.enable",
111 convertI32Node(
"llvm.loop.vectorize.width",
options.getWidth());
112 convertFollowupNode(
"llvm.loop.vectorize.followup_vectorized",
113 options.getFollowupVectorized());
114 convertFollowupNode(
"llvm.loop.vectorize.followup_epilogue",
115 options.getFollowupEpilogue());
116 convertFollowupNode(
"llvm.loop.vectorize.followup_all",
120void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr
options) {
121 convertI32Node(
"llvm.loop.interleave.count",
options.getCount());
124void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr
options) {
125 if (
auto disable =
options.getDisable())
126 addUnitNode(disable.getValue() ?
"llvm.loop.unroll.disable"
127 :
"llvm.loop.unroll.enable");
128 convertI32Node(
"llvm.loop.unroll.count",
options.getCount());
129 convertBoolNode(
"llvm.loop.unroll.runtime.disable",
131 addUnitNode(
"llvm.loop.unroll.full",
options.getFull());
132 convertFollowupNode(
"llvm.loop.unroll.followup_unrolled",
133 options.getFollowupUnrolled());
134 convertFollowupNode(
"llvm.loop.unroll.followup_remainder",
135 options.getFollowupRemainder());
136 convertFollowupNode(
"llvm.loop.unroll.followup_all",
140void LoopAnnotationConversion::convertLoopOptions(
141 LoopUnrollAndJamAttr
options) {
142 if (
auto disable =
options.getDisable())
143 addUnitNode(disable.getValue() ?
"llvm.loop.unroll_and_jam.disable"
144 :
"llvm.loop.unroll_and_jam.enable");
145 convertI32Node(
"llvm.loop.unroll_and_jam.count",
options.getCount());
146 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_outer",
148 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_inner",
150 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_outer",
151 options.getFollowupRemainderOuter());
152 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_inner",
153 options.getFollowupRemainderInner());
154 convertFollowupNode(
"llvm.loop.unroll_and_jam.followup_all",
158void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr
options) {
159 addUnitNode(
"llvm.licm.disable",
options.getDisable());
160 addUnitNode(
"llvm.loop.licm_versioning.disable",
161 options.getVersioningDisable());
164void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr
options) {
165 convertBoolNode(
"llvm.loop.distribute.enable",
options.getDisable(),
true);
166 convertFollowupNode(
"llvm.loop.distribute.followup_coincident",
167 options.getFollowupCoincident());
168 convertFollowupNode(
"llvm.loop.distribute.followup_sequential",
169 options.getFollowupSequential());
170 convertFollowupNode(
"llvm.loop.distribute.followup_fallback",
171 options.getFollowupFallback());
172 convertFollowupNode(
"llvm.loop.distribute.followup_all",
176void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr
options) {
177 convertBoolNode(
"llvm.loop.pipeline.disable",
options.getDisable());
178 convertI32Node(
"llvm.loop.pipeline.initiationinterval",
179 options.getInitiationinterval());
182void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr
options) {
183 convertI32Node(
"llvm.loop.peeled.count",
options.getCount());
186void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr
options) {
187 addUnitNode(
"llvm.loop.unswitch.partial.disable",
191void LoopAnnotationConversion::convertLocation(FusedLoc location) {
192 auto localScopeAttr =
193 dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
196 auto *localScope = dyn_cast<llvm::DILocalScope>(
201 llvm::Metadata *loc =
204 metadataNodes.push_back(loc);
207llvm::MDNode *LoopAnnotationConversion::convert() {
209 auto dummy = llvm::MDNode::getTemporary(ctx, {});
210 metadataNodes.push_back(dummy.get());
212 if (FusedLoc startLoc = attr.getStartLoc())
213 convertLocation(startLoc);
215 if (FusedLoc endLoc = attr.getEndLoc())
216 convertLocation(endLoc);
218 addUnitNode(
"llvm.loop.disable_nonforced", attr.getDisableNonforced());
219 addUnitNode(
"llvm.loop.mustprogress", attr.getMustProgress());
221 if (BoolAttr isVectorized = attr.getIsVectorized())
222 addI32NodeWithVal(
"llvm.loop.isvectorized", isVectorized.getValue());
224 if (
auto options = attr.getVectorize())
226 if (
auto options = attr.getInterleave())
228 if (
auto options = attr.getUnroll())
230 if (
auto options = attr.getUnrollAndJam())
232 if (
auto options = attr.getLicm())
234 if (
auto options = attr.getDistribute())
236 if (
auto options = attr.getPipeline())
238 if (
auto options = attr.getPeeled())
240 if (
auto options = attr.getUnswitch())
243 ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
244 if (!parallelAccessGroups.empty()) {
245 SmallVector<llvm::Metadata *> parallelAccess;
246 parallelAccess.push_back(
247 llvm::MDString::get(ctx,
"llvm.loop.parallel_accesses"));
248 for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249 parallelAccess.push_back(
251 metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
255 llvm::MDNode *loopMD = llvm::MDNode::get(ctx, metadataNodes);
256 loopMD->replaceOperandWith(0, loopMD);
267 llvm::MDNode *loopMD = lookupLoopMetadata(attr);
272 LoopAnnotationConversion(attr, op, *
this, this->llvmModule.
getContext())
276 mapLoopMetadata(attr, loopMD);
283 accessGroupMetadataMapping.try_emplace(accessGroupAttr);
285 result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
291 ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292 if (!accessGroups || accessGroups.empty())
296 for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
298 if (groupMDs.size() == 1)
299 return llvm::cast<llvm::MDNode>(groupMDs.front());
300 return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
[...] if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, [...]
static llvm::ManagedStatic< PassManagerOptions > options
bool getValue() const
Return the boolean value of this attribute.
llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::MDNode * translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op)
llvm::MDNode * getAccessGroups(AccessGroupOpInterface op)
Returns the LLVM metadata corresponding to the access group attribute referenced by the AccessGroupOp...
ModuleTranslation & moduleTranslation
The ModuleTranslation owning this instance.
llvm::MDNode * getAccessGroup(AccessGroupAttr accessGroupAttr)
Returns the LLVM metadata corresponding to an mlir LLVM dialect access group attribute.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Include the generated interface declarations.