MLIR  19.0.0git
CodeGenHelpers.cpp
Go to the documentation of this file.
1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/Support/Path.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace llvm;
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
25 /// Generate a unique label based on the current file name to prevent name
26 /// collisions if multiple generated files are included at once.
27 static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
28  // Use the input file name when generating a unique name.
29  std::string inputFilename = records.getInputFilename();
30 
31  // Drop all but the base filename.
32  StringRef nameRef = llvm::sys::path::filename(inputFilename);
33  nameRef.consume_back(".td");
34 
35  // Sanitize any invalid characters.
36  std::string uniqueName;
37  for (char c : nameRef) {
38  if (llvm::isAlnum(c) || c == '_')
39  uniqueName.push_back(c);
40  else
41  uniqueName.append(llvm::utohexstr((unsigned char)c));
42  }
43  return uniqueName;
44 }
45 
/// Construct an emitter that writes to `os`. The per-file label embedded in
/// every generated function name is computed once here from the input file
/// recorded in `records` (see getUniqueOutputLabel).
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
    raw_ostream &os, const llvm::RecordKeeper &records)
    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
49 
51  ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
52  collectOpConstraints(opDefs);
53  if (emitDecl)
54  return;
55 
56  NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
57  emitTypeConstraints();
58  emitAttrConstraints();
59  emitSuccessorConstraints();
60  emitRegionConstraints();
61 }
62 
63 void StaticVerifierFunctionEmitter::emitPatternConstraints(
64  const llvm::ArrayRef<DagLeaf> constraints) {
65  collectPatternConstraints(constraints);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Constraint Getters
71 
73  const Constraint &constraint) const {
74  const auto *it = typeConstraints.find(constraint);
75  assert(it != typeConstraints.end() && "expected to find a type constraint");
76  return it->second;
77 }
78 
79 // Find a uniqued attribute constraint. Since not all attribute constraints can
80 // be uniqued, return std::nullopt if one was not found.
82  const Constraint &constraint) const {
83  const auto *it = attrConstraints.find(constraint);
84  return it == attrConstraints.end() ? std::optional<StringRef>()
85  : StringRef(it->second);
86 }
87 
89  const Constraint &constraint) const {
90  const auto *it = successorConstraints.find(constraint);
91  assert(it != successorConstraints.end() &&
92  "expected to find a sucessor constraint");
93  return it->second;
94 }
95 
97  const Constraint &constraint) const {
98  const auto *it = regionConstraints.find(constraint);
99  assert(it != regionConstraints.end() &&
100  "expected to find a region constraint");
101  return it->second;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Constraint Emission
106 
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
///
/// The templates are expanded with `formatv` (see emitConstraints), which is
/// why some of them spell a literal opening brace as `{{`.

/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The generated function returns success when condition {1} holds for
/// `type`. Otherwise it emits an op error of the form
/// "<valueKind> #<valueIndex> must be {2}, but got <type>".
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
127 
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads named {0} are generated:
/// - One takes an `emitError` callback. A null `attr` passes the check;
///   otherwise it fails with "attribute '<attrName>' failed to satisfy
///   constraint: {2}" when condition {1} does not hold.
/// - One takes the op and forwards to the first, using `op->emitOpError()`
///   as the callback.
///
/// NOTE(review): the callback overload needs no op, so "called from ops only"
/// above may be outdated.
/// The `{{` sequences are escaped literal braces for `formatv`.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyway.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";
149 
/// Code for a successor constraint.
///
/// Template arguments: {0} is the unique function name, {1} the condition
/// code, and {2} the constraint description. The generated function returns
/// success when {1} holds for `successor`. Otherwise it emits an op error of
/// the form "successor #<successorIndex> ('<successorName>') failed to verify
/// constraint: {2}". The closing quote of the name must come before the
/// closing parenthesis.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
162 
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
///
/// The generated function returns success when condition {1} holds for
/// `region`. Otherwise it emits an op error of the form
/// "region #<regionIndex> ('<regionName>') failed to verify constraint: {2}".
/// The parenthesized name is omitted when `regionName` is empty.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
177 
/// Code for a pattern type or attribute constraint.
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
/// {3}: "Type type" or "Attribute attr".
///
/// The generated function returns success when {1} holds. Otherwise it reports
/// a match failure through `rewriter.notifyMatchFailure` with the diagnostic
/// text "<failureStr>: {2}".
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
193 
194 void StaticVerifierFunctionEmitter::emitConstraints(
195  const ConstraintMap &constraints, StringRef selfName,
196  const char *const codeTemplate) {
197  FmtContext ctx;
198  ctx.addSubst("_op", "*op").withSelf(selfName);
199  for (auto &it : constraints) {
200  os << formatv(codeTemplate, it.second,
201  tgfmt(it.first.getConditionTemplate(), &ctx),
202  escapeString(it.first.getSummary()));
203  }
204 }
205 
206 void StaticVerifierFunctionEmitter::emitTypeConstraints() {
207  emitConstraints(typeConstraints, "type", typeConstraintCode);
208 }
209 
210 void StaticVerifierFunctionEmitter::emitAttrConstraints() {
211  emitConstraints(attrConstraints, "attr", attrConstraintCode);
212 }
213 
214 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
215  emitConstraints(successorConstraints, "successor", successorConstraintCode);
216 }
217 
218 void StaticVerifierFunctionEmitter::emitRegionConstraints() {
219  emitConstraints(regionConstraints, "region", regionConstraintCode);
220 }
221 
222 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
223  FmtContext ctx;
224  ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
225  for (auto &it : typeConstraints) {
226  os << formatv(patternAttrOrTypeConstraintCode, it.second,
227  tgfmt(it.first.getConditionTemplate(), &ctx),
228  escapeString(it.first.getSummary()), "Type type");
229  }
230  ctx.withSelf("attr");
231  for (auto &it : attrConstraints) {
232  os << formatv(patternAttrOrTypeConstraintCode, it.second,
233  tgfmt(it.first.getConditionTemplate(), &ctx),
234  escapeString(it.first.getSummary()), "Attribute attr");
235  }
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Constraint Uniquing
240 
241 /// An attribute constraint that references anything other than itself and the
242 /// current op cannot be generically extracted into a function. Most
243 /// prohibitive are operands and results, which require calls to
244 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too
245 /// because ops use cached identifiers.
247  FmtContext ctx;
248  auto test = tgfmt(attr.getConditionTemplate(),
249  &ctx.withSelf("attr").addSubst("_op", "*op"))
250  .str();
251  return !StringRef(test).contains("<no-subst-found>");
252 }
253 
254 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
255  unsigned index) {
256  return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
257  Twine(index))
258  .str();
259 }
260 
261 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
262  StringRef kind,
263  Constraint constraint) {
264  auto *it = map.find(constraint);
265  if (it == map.end())
266  map.insert({constraint, getUniqueName(kind, map.size())});
267 }
268 
269 void StaticVerifierFunctionEmitter::collectOpConstraints(
270  ArrayRef<Record *> opDefs) {
271  const auto collectTypeConstraints = [&](Operator::const_value_range values) {
272  for (const NamedTypeConstraint &value : values)
273  if (value.hasPredicate())
274  collectConstraint(typeConstraints, "type", value.constraint);
275  };
276 
277  for (Record *def : opDefs) {
278  Operator op(*def);
279  /// Collect type constraints.
280  collectTypeConstraints(op.getOperands());
281  collectTypeConstraints(op.getResults());
282  /// Collect attribute constraints.
283  for (const NamedAttribute &namedAttr : op.getAttributes()) {
284  if (!namedAttr.attr.getPredicate().isNull() &&
285  !namedAttr.attr.isDerivedAttr() &&
286  canUniqueAttrConstraint(namedAttr.attr))
287  collectConstraint(attrConstraints, "attr", namedAttr.attr);
288  }
289  /// Collect successor constraints.
290  for (const NamedSuccessor &successor : op.getSuccessors()) {
291  if (!successor.constraint.getPredicate().isNull()) {
292  collectConstraint(successorConstraints, "successor",
293  successor.constraint);
294  }
295  }
296  /// Collect region constraints.
297  for (const NamedRegion &region : op.getRegions())
298  if (!region.constraint.getPredicate().isNull())
299  collectConstraint(regionConstraints, "region", region.constraint);
300  }
301 }
302 
303 void StaticVerifierFunctionEmitter::collectPatternConstraints(
304  const llvm::ArrayRef<DagLeaf> constraints) {
305  for (auto &leaf : constraints) {
306  assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
307  collectConstraint(
308  leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
309  leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
310  }
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Public Utility Functions
315 //===----------------------------------------------------------------------===//
316 
317 std::string mlir::tblgen::escapeString(StringRef value) {
318  std::string ret;
319  llvm::raw_string_ostream os(ret);
320  os.write_escaped(value);
321  return os.str();
322 }
static const char *const successorConstraintCode
Code for a successor constraint.
static const char *const regionConstraintCode
Code for a region constraint.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records)
Generate a unique label based on the current file name to prevent name collisions if multiple generated files are included at once.
static const char *const typeConstraintCode
Code templates for emitting type, attribute, successor, and region constraints.
static const char *const patternAttrOrTypeConstraintCode
Code for a pattern type or attribute constraint.
static bool canUniqueAttrConstraint(Attribute attr)
An attribute constraint that references anything other than itself and the current op cannot be generically extracted into a function.
static const char *const attrConstraintCode
Code for an attribute constraint.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
Format context containing substitutions for special placeholders.
Definition: Format.h:40
FmtContext & withBuilder(Twine subst)
Definition: Format.cpp:36
FmtContext & withSelf(Twine subst)
Definition: Format.cpp:41
FmtContext & addSubst(StringRef placeholder, const Twine &subst)
Definition: Format.cpp:31
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
StringRef getRegionConstraintFn(const Constraint &constraint) const
Get the name of the static function used for the given region constraint.
void emitPatternConstraints(const ArrayRef< DagLeaf > constraints)
Unique all compatible type and attribute constraints from a pattern file and emit them at the top of ...
void emitOpConstraints(ArrayRef< llvm::Record * > opDefs, bool emitDecl)
Collect and unique all compatible type, attribute, successor, and region constraints from the operati...
std::optional< StringRef > getAttrConstraintFn(const Constraint &constraint) const
Get the name of the static function used for the given attribute constraint.
StringRef getTypeConstraintFn(const Constraint &constraint) const
Get the name of the static function used for the given type constraint.
StringRef getSuccessorConstraintFn(const Constraint &constraint) const
Get the name of the static function used for the given successor constraint.
Include the generated interface declarations.
Definition: CallGraph.h:229
auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject< decltype(std::make_tuple(llvm::support::detail::build_format_adapter(std::forward< Ts >(vals))...))>
Formats text by substituting placeholders in format string with replacement parameters.
Definition: Format.h:262
std::string escapeString(StringRef value)
Escape a string using C++ encoding (via raw_ostream::write_escaped). E.g. foo"bar -> foo\"bar.
Include the generated interface declarations.