MLIR  19.0.0git
LoopAnnotationTranslation.cpp
Go to the documentation of this file.
1 //===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "llvm/IR/DebugInfoMetadata.h"
11 
12 using namespace mlir;
13 using namespace mlir::LLVM;
14 using namespace mlir::LLVM::detail;
15 
16 namespace {
17 /// Helper class that keeps the state of one attribute to metadata conversion.
18 struct LoopAnnotationConversion {
19  LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
20  LoopAnnotationTranslation &loopAnnotationTranslation,
21  llvm::LLVMContext &ctx)
22  : attr(attr), op(op),
23  loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
24 
25  /// Converts this struct's loop annotation into a corresponding LLVMIR
26  /// metadata representation.
27  llvm::MDNode *convert();
28 
29  /// Conversion functions for different payload attribute kinds.
30  void addUnitNode(StringRef name);
31  void addUnitNode(StringRef name, BoolAttr attr);
32  void addI32NodeWithVal(StringRef name, uint32_t val);
33  void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
34  void convertI32Node(StringRef name, IntegerAttr attr);
35  void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
36  void convertLocation(FusedLoc attr);
37 
38  /// Conversion functions for each for each loop annotation sub-attribute.
39  void convertLoopOptions(LoopVectorizeAttr options);
40  void convertLoopOptions(LoopInterleaveAttr options);
41  void convertLoopOptions(LoopUnrollAttr options);
42  void convertLoopOptions(LoopUnrollAndJamAttr options);
43  void convertLoopOptions(LoopLICMAttr options);
44  void convertLoopOptions(LoopDistributeAttr options);
45  void convertLoopOptions(LoopPipelineAttr options);
46  void convertLoopOptions(LoopPeeledAttr options);
47  void convertLoopOptions(LoopUnswitchAttr options);
48 
49  LoopAnnotationAttr attr;
50  Operation *op;
51  LoopAnnotationTranslation &loopAnnotationTranslation;
52  llvm::LLVMContext &ctx;
54 };
55 } // namespace
56 
57 void LoopAnnotationConversion::addUnitNode(StringRef name) {
58  metadataNodes.push_back(
59  llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name)}));
60 }
61 
62 void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
63  if (attr && attr.getValue())
64  addUnitNode(name);
65 }
66 
67 void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
68  llvm::Constant *cstValue = llvm::ConstantInt::get(
69  llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false);
70  metadataNodes.push_back(
71  llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
72  llvm::ConstantAsMetadata::get(cstValue)}));
73 }
74 
75 void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
76  bool negated) {
77  if (!attr)
78  return;
79  bool val = negated ^ attr.getValue();
80  llvm::Constant *cstValue = llvm::ConstantInt::getBool(ctx, val);
81  metadataNodes.push_back(
82  llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
83  llvm::ConstantAsMetadata::get(cstValue)}));
84 }
85 
86 void LoopAnnotationConversion::convertI32Node(StringRef name,
87  IntegerAttr attr) {
88  if (!attr)
89  return;
90  addI32NodeWithVal(name, attr.getInt());
91 }
92 
93 void LoopAnnotationConversion::convertFollowupNode(StringRef name,
94  LoopAnnotationAttr attr) {
95  if (!attr)
96  return;
97 
98  llvm::MDNode *node =
99  loopAnnotationTranslation.translateLoopAnnotation(attr, op);
100 
101  metadataNodes.push_back(
102  llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
103 }
104 
105 void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) {
106  convertBoolNode("llvm.loop.vectorize.enable", options.getDisable(), true);
107  convertBoolNode("llvm.loop.vectorize.predicate.enable",
108  options.getPredicateEnable());
109  convertBoolNode("llvm.loop.vectorize.scalable.enable",
110  options.getScalableEnable());
111  convertI32Node("llvm.loop.vectorize.width", options.getWidth());
112  convertFollowupNode("llvm.loop.vectorize.followup_vectorized",
113  options.getFollowupVectorized());
114  convertFollowupNode("llvm.loop.vectorize.followup_epilogue",
115  options.getFollowupEpilogue());
116  convertFollowupNode("llvm.loop.vectorize.followup_all",
117  options.getFollowupAll());
118 }
119 
120 void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) {
121  convertI32Node("llvm.loop.interleave.count", options.getCount());
122 }
123 
124 void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
125  if (auto disable = options.getDisable())
126  addUnitNode(disable.getValue() ? "llvm.loop.unroll.disable"
127  : "llvm.loop.unroll.enable");
128  convertI32Node("llvm.loop.unroll.count", options.getCount());
129  convertBoolNode("llvm.loop.unroll.runtime.disable",
130  options.getRuntimeDisable());
131  addUnitNode("llvm.loop.unroll.full", options.getFull());
132  convertFollowupNode("llvm.loop.unroll.followup_unrolled",
133  options.getFollowupUnrolled());
134  convertFollowupNode("llvm.loop.unroll.followup_remainder",
135  options.getFollowupRemainder());
136  convertFollowupNode("llvm.loop.unroll.followup_all",
137  options.getFollowupAll());
138 }
139 
140 void LoopAnnotationConversion::convertLoopOptions(
141  LoopUnrollAndJamAttr options) {
142  if (auto disable = options.getDisable())
143  addUnitNode(disable.getValue() ? "llvm.loop.unroll_and_jam.disable"
144  : "llvm.loop.unroll_and_jam.enable");
145  convertI32Node("llvm.loop.unroll_and_jam.count", options.getCount());
146  convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer",
147  options.getFollowupOuter());
148  convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner",
149  options.getFollowupInner());
150  convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
151  options.getFollowupRemainderOuter());
152  convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
153  options.getFollowupRemainderInner());
154  convertFollowupNode("llvm.loop.unroll_and_jam.followup_all",
155  options.getFollowupAll());
156 }
157 
158 void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) {
159  addUnitNode("llvm.licm.disable", options.getDisable());
160  addUnitNode("llvm.loop.licm_versioning.disable",
161  options.getVersioningDisable());
162 }
163 
164 void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) {
165  convertBoolNode("llvm.loop.distribute.enable", options.getDisable(), true);
166  convertFollowupNode("llvm.loop.distribute.followup_coincident",
167  options.getFollowupCoincident());
168  convertFollowupNode("llvm.loop.distribute.followup_sequential",
169  options.getFollowupSequential());
170  convertFollowupNode("llvm.loop.distribute.followup_fallback",
171  options.getFollowupFallback());
172  convertFollowupNode("llvm.loop.distribute.followup_all",
173  options.getFollowupAll());
174 }
175 
176 void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) {
177  convertBoolNode("llvm.loop.pipeline.disable", options.getDisable());
178  convertI32Node("llvm.loop.pipeline.initiationinterval",
179  options.getInitiationinterval());
180 }
181 
182 void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) {
183  convertI32Node("llvm.loop.peeled.count", options.getCount());
184 }
185 
186 void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
187  addUnitNode("llvm.loop.unswitch.partial.disable",
188  options.getPartialDisable());
189 }
190 
191 void LoopAnnotationConversion::convertLocation(FusedLoc location) {
192  auto localScopeAttr =
193  dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
194  if (!localScopeAttr)
195  return;
196  auto *localScope = dyn_cast<llvm::DILocalScope>(
197  loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
198  localScopeAttr));
199  if (!localScope)
200  return;
201  llvm::Metadata *loc =
202  loopAnnotationTranslation.moduleTranslation.translateLoc(location,
203  localScope);
204  metadataNodes.push_back(loc);
205 }
206 
207 llvm::MDNode *LoopAnnotationConversion::convert() {
208  // Reserve operand 0 for loop id self reference.
209  auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
210  metadataNodes.push_back(dummy.get());
211 
212  if (FusedLoc startLoc = attr.getStartLoc())
213  convertLocation(startLoc);
214 
215  if (FusedLoc endLoc = attr.getEndLoc())
216  convertLocation(endLoc);
217 
218  addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
219  addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
220  // "isvectorized" is encoded as an i32 value.
221  if (BoolAttr isVectorized = attr.getIsVectorized())
222  addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue());
223 
224  if (auto options = attr.getVectorize())
225  convertLoopOptions(options);
226  if (auto options = attr.getInterleave())
227  convertLoopOptions(options);
228  if (auto options = attr.getUnroll())
229  convertLoopOptions(options);
230  if (auto options = attr.getUnrollAndJam())
231  convertLoopOptions(options);
232  if (auto options = attr.getLicm())
233  convertLoopOptions(options);
234  if (auto options = attr.getDistribute())
235  convertLoopOptions(options);
236  if (auto options = attr.getPipeline())
237  convertLoopOptions(options);
238  if (auto options = attr.getPeeled())
239  convertLoopOptions(options);
240  if (auto options = attr.getUnswitch())
241  convertLoopOptions(options);
242 
243  ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
244  if (!parallelAccessGroups.empty()) {
245  SmallVector<llvm::Metadata *> parallelAccess;
246  parallelAccess.push_back(
247  llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
248  for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
249  parallelAccess.push_back(
250  loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
251  metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
252  }
253 
254  // Create loop options and set the first operand to itself.
255  llvm::MDNode *loopMD = llvm::MDNode::get(ctx, metadataNodes);
256  loopMD->replaceOperandWith(0, loopMD);
257 
258  return loopMD;
259 }
260 
261 llvm::MDNode *
263  Operation *op) {
264  if (!attr)
265  return nullptr;
266 
267  llvm::MDNode *loopMD = lookupLoopMetadata(attr);
268  if (loopMD)
269  return loopMD;
270 
271  loopMD =
272  LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
273  .convert();
274  // Store a map from this Attribute to the LLVM metadata in case we
275  // encounter it again.
276  mapLoopMetadata(attr, loopMD);
277  return loopMD;
278 }
279 
280 llvm::MDNode *
281 LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) {
282  auto [result, inserted] =
283  accessGroupMetadataMapping.insert({accessGroupAttr, nullptr});
284  if (inserted)
285  result->second = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
286  return result->second;
287 }
288 
289 llvm::MDNode *
290 LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) {
291  ArrayAttr accessGroups = op.getAccessGroupsOrNull();
292  if (!accessGroups || accessGroups.empty())
293  return nullptr;
294 
296  for (AccessGroupAttr group : accessGroups.getAsRange<AccessGroupAttr>())
297  groupMDs.push_back(getAccessGroup(group));
298  if (groupMDs.size() == 1)
299  return llvm::cast<llvm::MDNode>(groupMDs.front());
300  return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
301 }
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
llvm::MDNode * translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op)
llvm::MDNode * getAccessGroups(AccessGroupOpInterface op)
Returns the LLVM metadata corresponding to the access group attribute referenced by the AccessGroupOp...
llvm::MDNode * getAccessGroup(AccessGroupAttr accessGroupAttr)
Returns the LLVM metadata corresponding to an mlir LLVM dialect access group attribute.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...