MLIR  16.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/TableGen/Argument.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSet.h"
24 
25 #include <unordered_map>
26 
27 namespace llvm {
28 class DagInit;
29 class Init;
30 class Record;
31 } // namespace llvm
32 
33 namespace mlir {
34 namespace tblgen {
35 
36 // Mapping from TableGen Record to Operator wrapper object.
37 //
38 // We allocate each wrapper object in heap to make sure the pointer to it is
39 // valid throughout the lifetime of this map. This is important because this map
40 // is shared among multiple patterns to avoid creating the wrapper object for
41 // the same op again and again. But this map will continuously grow.
42 using RecordOperatorMap =
44 
45 class Pattern;
46 
47 // Wrapper class providing helper methods for accessing TableGen DAG leaves
48 // used inside Patterns. This class is lightweight and designed to be used like
49 // values.
50 //
51 // A TableGen DAG construct is of the syntax
52 // `(operator, arg0, arg1, ...)`.
53 //
54 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
55 // for handy helper methods. It only works on `arg*`s that are not nested DAG
56 // constructs.
57 class DagLeaf {
58 public:
59  explicit DagLeaf(const llvm::Init *def) : def(def) {}
60 
61  // Returns true if this DAG leaf is not specified in the pattern. That is, it
62  // places no further constraints/transforms and just carries over the original
63  // value.
64  bool isUnspecified() const;
65 
66  // Returns true if this DAG leaf is matching an operand. That is, it specifies
67  // a type constraint.
68  bool isOperandMatcher() const;
69 
70  // Returns true if this DAG leaf is matching an attribute. That is, it
71  // specifies an attribute constraint.
72  bool isAttrMatcher() const;
73 
74  // Returns true if this DAG leaf is wrapping native code call.
75  bool isNativeCodeCall() const;
76 
77  // Returns true if this DAG leaf is specifying a constant attribute.
78  bool isConstantAttr() const;
79 
80  // Returns true if this DAG leaf is specifying an enum attribute case.
81  bool isEnumAttrCase() const;
82 
83  // Returns true if this DAG leaf is specifying a string attribute.
84  bool isStringAttr() const;
85 
86  // Returns this DAG leaf as a constraint. Asserts if fails.
87  Constraint getAsConstraint() const;
88 
89  // Returns this DAG leaf as an constant attribute. Asserts if fails.
90  ConstantAttr getAsConstantAttr() const;
91 
92  // Returns this DAG leaf as an enum attribute case.
93  // Precondition: isEnumAttrCase()
94  EnumAttrCase getAsEnumAttrCase() const;
95 
96  // Returns the matching condition template inside this DAG leaf. Assumes the
97  // leaf is an operand/attribute matcher and asserts otherwise.
98  std::string getConditionTemplate() const;
99 
100  // Returns the native code call template inside this DAG leaf.
101  // Precondition: isNativeCodeCall()
102  StringRef getNativeCodeTemplate() const;
103 
104  // Returns the number of values will be returned by the native helper
105  // function.
106  // Precondition: isNativeCodeCall()
107  int getNumReturnsOfNativeCode() const;
108 
109  // Returns the string associated with the leaf.
110  // Precondition: isStringAttr()
111  std::string getStringAttr() const;
112 
113  void print(raw_ostream &os) const;
114 
115 private:
117  const void *getAsOpaquePointer() const { return def; }
118 
119  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
120  // also a subclass of the given `superclass`.
121  bool isSubClassOf(StringRef superclass) const;
122 
123  const llvm::Init *def;
124 };
125 
126 // Wrapper class providing helper methods for accessing TableGen DAG constructs
127 // used inside Patterns. This class is lightweight and designed to be used like
128 // values.
129 //
130 // A TableGen DAG construct is of the syntax
131 // `(operator, arg0, arg1, ...)`.
132 //
133 // When used inside Patterns, `operator` corresponds to some dialect op, or
134 // a known list of verbs that defines special transformation actions. This
135 // `arg*` can be a nested DAG construct. This class provides getters to
136 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
137 // methods.
138 //
139 // A null DagNode contains a nullptr and converts to false implicitly.
140 class DagNode {
141 public:
142  explicit DagNode(const llvm::DagInit *node) : node(node) {}
143 
144  // Implicit bool converter that returns true if this DagNode is not a null
145  // DagNode.
146  operator bool() const { return node != nullptr; }
147 
148  // Returns the symbol bound to this DAG node.
149  StringRef getSymbol() const;
150 
151  // Returns the operator wrapper object corresponding to the dialect op matched
152  // by this DAG. The operator wrapper will be queried from the given `mapper`
153  // and created in it if not existing.
154  Operator &getDialectOp(RecordOperatorMap *mapper) const;
155 
156  // Returns the number of operations recursively involved in the DAG tree
157  // rooted from this node.
158  int getNumOps() const;
159 
160  // Returns the number of immediate arguments to this DAG node.
161  int getNumArgs() const;
162 
163  // Returns true if the `index`-th argument is a nested DAG construct.
164  bool isNestedDagArg(unsigned index) const;
165 
166  // Gets the `index`-th argument as a nested DAG construct if possible. Returns
167  // null DagNode otherwise.
168  DagNode getArgAsNestedDag(unsigned index) const;
169 
170  // Gets the `index`-th argument as a DAG leaf.
171  DagLeaf getArgAsLeaf(unsigned index) const;
172 
173  // Returns the specified name of the `index`-th argument.
174  StringRef getArgName(unsigned index) const;
175 
176  // Returns true if this DAG construct means to replace with an existing SSA
177  // value.
178  bool isReplaceWithValue() const;
179 
180  // Returns whether this DAG represents the location of an op creation.
181  bool isLocationDirective() const;
182 
183  // Returns whether this DAG is a return type specifier.
184  bool isReturnTypeDirective() const;
185 
186  // Returns true if this DAG node is wrapping native code call.
187  bool isNativeCodeCall() const;
188 
189  // Returns whether this DAG is an `either` specifier.
190  bool isEither() const;
191 
192  // Returns true if this DAG node is an operation.
193  bool isOperation() const;
194 
195  // Returns the native code call template inside this DAG node.
196  // Precondition: isNativeCodeCall()
197  StringRef getNativeCodeTemplate() const;
198 
199  // Returns the number of values will be returned by the native helper
200  // function.
201  // Precondition: isNativeCodeCall()
202  int getNumReturnsOfNativeCode() const;
203 
204  void print(raw_ostream &os) const;
205 
206 private:
207  friend class SymbolInfoMap;
209  const void *getAsOpaquePointer() const { return node; }
210 
211  const llvm::DagInit *node; // nullptr means null DagNode
212 };
213 
214 // A class for maintaining information for symbols bound in patterns and
215 // provides methods for resolving them according to specific use cases.
216 //
217 // Symbols can be bound to
218 //
219 // * Op arguments and op results in the source pattern and
220 // * Op results in result patterns.
221 //
222 // Symbols can be referenced in result patterns and additional constraints to
223 // the pattern.
224 //
225 // For example, in
226 //
227 // ```
228 // def : Pattern<
229 // (SrcOp:$results1 $arg0, %arg1),
230 // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
231 // ```
232 //
233 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
234 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
235 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
236 //
237 // If a symbol binds to a multi-result op and it does not have the `__N`
238 // suffix, the symbol is expanded to represent all results generated by the
239 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
240 // only the N-th *static* result as declared in ODS, and that can still
241 // corresponds to multiple *dynamic* values if the N-th *static* result is
242 // variadic.
243 //
244 // This class keeps track of such symbols and resolves them into their bound
245 // values in a suitable way.
247 public:
248  explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {}
249 
250  // Class for information regarding a symbol.
251  class SymbolInfo {
252  public:
253  // Returns a type string of a variable.
254  std::string getVarTypeStr(StringRef name) const;
255 
256  // Returns a string for defining a variable named as `name` to store the
257  // value bound by this symbol.
258  std::string getVarDecl(StringRef name) const;
259 
260  // Returns a string for defining an argument which passes the reference of
261  // the variable.
262  std::string getArgDecl(StringRef name) const;
263 
264  // Returns a variable name for the symbol named as `name`.
265  std::string getVarName(StringRef name) const;
266 
267  private:
268  // Allow SymbolInfoMap to access private methods.
269  friend class SymbolInfoMap;
270 
271  // DagNode and DagLeaf are accessed by value which means it can't be used as
272  // identifier here. Use an opaque pointer type instead.
273  using DagAndConstant = std::pair<const void *, int>;
274 
275  // What kind of entity this symbol represents:
276  // * Attr: op attribute
277  // * Operand: op operand
278  // * Result: op result
279  // * Value: a value not attached to an op (e.g., from NativeCodeCall)
280  // * MultipleValues: a pack of values not attached to an op (e.g., from
281  // NativeCodeCall). This kind supports indexing.
282  enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
283 
284  // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
285  // and `Operand` so should be llvm::None for `Result` and `Value` kind.
286  SymbolInfo(const Operator *op, Kind kind,
287  Optional<DagAndConstant> dagAndConstant);
288 
289  // Static methods for creating SymbolInfo.
290  static SymbolInfo getAttr(const Operator *op, int index) {
291  return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
292  }
293  static SymbolInfo getAttr() {
294  return SymbolInfo(nullptr, Kind::Attr, llvm::None);
295  }
296  static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
297  return SymbolInfo(op, Kind::Operand,
298  DagAndConstant(node.getAsOpaquePointer(), index));
299  }
300  static SymbolInfo getResult(const Operator *op) {
301  return SymbolInfo(op, Kind::Result, llvm::None);
302  }
303  static SymbolInfo getValue() {
304  return SymbolInfo(nullptr, Kind::Value, llvm::None);
305  }
306  static SymbolInfo getMultipleValues(int numValues) {
307  return SymbolInfo(nullptr, Kind::MultipleValues,
308  DagAndConstant(nullptr, numValues));
309  }
310 
311  // Returns the number of static values this symbol corresponds to.
312  // A static value is an operand/result declared in ODS. Normally a symbol
313  // only represents one static value, but symbols bound to op results can
314  // represent more than one if the op is a multi-result op.
315  int getStaticValueCount() const;
316 
317  // Returns a string containing the C++ expression for referencing this
318  // symbol as a value (if this symbol represents one static value) or a value
319  // range (if this symbol represents multiple static values). `name` is the
320  // name of the C++ variable that this symbol bounds to. `index` should only
321  // be used for indexing results. `fmt` is used to format each value.
322  // `separator` is used to separate values if this is a value range.
323  std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
324  const char *separator) const;
325 
326  // Returns a string containing the C++ expression for referencing this
327  // symbol as a value range regardless of how many static values this symbol
328  // represents. `name` is the name of the C++ variable that this symbol
329  // bounds to. `index` should only be used for indexing results. `fmt` is
330  // used to format each value. `separator` is used to separate values in the
331  // range.
332  std::string getAllRangeUse(StringRef name, int index, const char *fmt,
333  const char *separator) const;
334 
335  // The argument index (for `Attr` and `Operand` only)
336  int getArgIndex() const { return (*dagAndConstant).second; }
337 
338  // The number of values in the MultipleValue
339  int getSize() const { return (*dagAndConstant).second; }
340 
341  const Operator *op; // The op where the bound entity belongs
342  Kind kind; // The kind of the bound entity
343 
344  // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and
345  // the size of MultipleValue symbol). Note that operands may be bound to the
346  // same symbol, use the DagNode and index to distinguish them. For `Attr`
347  // and MultipleValue, the Dag part will be nullptr.
348  Optional<DagAndConstant> dagAndConstant;
349 
350  // Alternative name for the symbol. It is used in case the name
351  // is not unique. Applicable for `Operand` only.
352  Optional<std::string> alternativeName;
353  };
354 
355  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
356 
357  // Iterators for accessing all symbols.
358  using iterator = BaseT::iterator;
359  iterator begin() { return symbolInfoMap.begin(); }
360  iterator end() { return symbolInfoMap.end(); }
361 
362  // Const iterators for accessing all symbols.
363  using const_iterator = BaseT::const_iterator;
364  const_iterator begin() const { return symbolInfoMap.begin(); }
365  const_iterator end() const { return symbolInfoMap.end(); }
366 
367  // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
368  // Returns false if `symbol` is already bound and symbols are not operands.
369  bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
370  int argIndex);
371 
372  // Binds the given `symbol` to the results the given `op`. Returns false if
373  // `symbol` is already bound.
374  bool bindOpResult(StringRef symbol, const Operator &op);
375 
376  // A helper function for dispatching target value binding functions.
377  bool bindValues(StringRef symbol, int numValues = 1);
378 
379  // Registers the given `symbol` as bound to the Value(s). Returns false if
380  // `symbol` is already bound.
381  bool bindValue(StringRef symbol);
382 
383  // Registers the given `symbol` as bound to a MultipleValue. Return false if
384  // `symbol` is already bound.
385  bool bindMultipleValues(StringRef symbol, int numValues);
386 
387  // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
388  // is already bound.
389  bool bindAttr(StringRef symbol);
390 
391  // Returns true if the given `symbol` is bound.
392  bool contains(StringRef symbol) const;
393 
394  // Returns an iterator to the information of the given symbol named as `key`.
395  const_iterator find(StringRef key) const;
396 
397  // Returns an iterator to the information of the given symbol named as `key`,
398  // with index `argIndex` for operator `op`.
399  const_iterator findBoundSymbol(StringRef key, DagNode node,
400  const Operator &op, int argIndex) const;
401  const_iterator findBoundSymbol(StringRef key,
402  const SymbolInfo &symbolInfo) const;
403 
404  // Returns the bounds of a range that includes all the elements which
405  // bind to the `key`.
406  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
407 
408  // Returns number of times symbol named as `key` was used.
409  int count(StringRef key) const;
410 
411  // Returns the number of static values of the given `symbol` corresponds to.
412  // A static value is an operand/result declared in ODS. Normally a symbol only
413  // represents one static value, but symbols bound to op results can represent
414  // more than one if the op is a multi-result op.
415  int getStaticValueCount(StringRef symbol) const;
416 
417  // Returns a string containing the C++ expression for referencing this
418  // symbol as a value (if this symbol represents one static value) or a value
419  // range (if this symbol represents multiple static values). `fmt` is used to
420  // format each value. `separator` is used to separate values if `symbol`
421  // represents a value range.
422  std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
423  const char *separator = ", ") const;
424 
425  // Returns a string containing the C++ expression for referencing this
426  // symbol as a value range regardless of how many static values this symbol
427  // represents. `fmt` is used to format each value. `separator` is used to
428  // separate values in the range.
429  std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
430  const char *separator = ", ") const;
431 
432  // Assign alternative unique names to Operands that have equal names.
433  void assignUniqueAlternativeNames();
434 
435  // Splits the given `symbol` into a value pack name and an index. Returns the
436  // value pack name and writes the index to `index` on success. Returns
437  // `symbol` itself if it does not contain an index.
438  //
439  // We can use `name__N` to access the `N`-th value in the value pack bound to
440  // `name`. `name` is typically the results of an multi-result op.
441  static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
442 
443 private:
444  BaseT symbolInfoMap;
445 
446  // Pattern instantiation location. This is intended to be used as parameter
447  // to PrintFatalError() to report errors.
448  ArrayRef<SMLoc> loc;
449 };
450 
451 // Wrapper class providing helper methods for accessing MLIR Pattern defined
452 // in TableGen. This class should closely reflect what is defined as class
453 // `Pattern` in TableGen. This class contains maps so it is not intended to be
454 // used as values.
455 class Pattern {
456 public:
457  explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
458 
459  // Returns the source pattern to match.
460  DagNode getSourcePattern() const;
461 
462  // Returns the number of result patterns generated by applying this rewrite
463  // rule.
464  int getNumResultPatterns() const;
465 
466  // Returns the DAG tree root node of the `index`-th result pattern.
467  DagNode getResultPattern(unsigned index) const;
468 
469  // Collects all symbols bound in the source pattern into `infoMap`.
470  void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
471 
472  // Collects all symbols bound in result patterns into `infoMap`.
473  void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
474 
475  // Returns the op that the root node of the source pattern matches.
476  const Operator &getSourceRootOp();
477 
478  // Returns the operator wrapper object corresponding to the given `node`'s DAG
479  // operator.
480  Operator &getDialectOp(DagNode node);
481 
482  // Returns the constraints.
483  std::vector<AppliedConstraint> getConstraints() const;
484 
485  // Returns the benefit score of the pattern.
486  int getBenefit() const;
487 
488  using IdentifierLine = std::pair<StringRef, unsigned>;
489 
490  // Returns the file location of the pattern (buffer identifier + line number
491  // pair).
492  std::vector<IdentifierLine> getLocation() const;
493 
494  // Recursively collects all bound symbols inside the DAG tree rooted
495  // at `tree` and updates the given `infoMap`.
496  void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
497  bool isSrcPattern);
498 
499 private:
500  // Helper function to verify variable binding.
501  void verifyBind(bool result, StringRef symbolName);
502 
503  // The TableGen definition of this pattern.
504  const llvm::Record &def;
505 
506  // All operators.
507  // TODO: we need a proper context manager, like MLIRContext, for managing the
508  // lifetime of shared entities.
509  RecordOperatorMap *recordOpMap;
510 };
511 
512 } // namespace tblgen
513 } // namespace mlir
514 
515 namespace llvm {
516 template <>
517 struct DenseMapInfo<mlir::tblgen::DagNode> {
519  return mlir::tblgen::DagNode(
521  }
523  return mlir::tblgen::DagNode(
525  }
526  static unsigned getHashValue(mlir::tblgen::DagNode node) {
527  return llvm::hash_value(node.getAsOpaquePointer());
528  }
530  return lhs.node == rhs.node;
531  }
532 };
533 
534 template <>
535 struct DenseMapInfo<mlir::tblgen::DagLeaf> {
537  return mlir::tblgen::DagLeaf(
539  }
541  return mlir::tblgen::DagLeaf(
543  }
544  static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
545  return llvm::hash_value(leaf.getAsOpaquePointer());
546  }
548  return lhs.def == rhs.def;
549  }
550 };
551 } // namespace llvm
552 
553 #endif // MLIR_TABLEGEN_PATTERN_H_
Kind
Tensor expression kind.
Definition: Merger.h:40
Include the generated interface declarations.
static unsigned getHashValue(mlir::tblgen::DagNode node)
Definition: Pattern.h:526
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
static mlir::tblgen::DagNode getTombstoneKey()
Definition: Pattern.h:522
llvm::hash_code hash_value(const MPInt &x)
Redeclarations of friend declaration above to make it discoverable by lookups.
Definition: MPInt.cpp:15
static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs)
Definition: Pattern.h:529
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static unsigned getHashValue(mlir::tblgen::DagLeaf leaf)
Definition: Pattern.h:544
BaseT::const_iterator const_iterator
Definition: Pattern.h:363
static mlir::tblgen::DagLeaf getEmptyKey()
Definition: Pattern.h:536
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Definition: Pattern.h:43
SymbolInfoMap(ArrayRef< SMLoc > loc)
Definition: Pattern.h:248
static bool contains(SMRange range, SMLoc loc)
Returns true if the given range contains the given source location.
Definition: MLIRServer.cpp:100
static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs)
Definition: Pattern.h:547
static mlir::tblgen::DagLeaf getTombstoneKey()
Definition: Pattern.h:540
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
DagLeaf(const llvm::Init *def)
Definition: Pattern.h:59
std::pair< StringRef, unsigned > IdentifierLine
Definition: Pattern.h:488
const_iterator begin() const
Definition: Pattern.h:364
BaseT::iterator iterator
Definition: Pattern.h:358
static mlir::tblgen::DagNode getEmptyKey()
Definition: Pattern.h:518
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:142
const_iterator end() const
Definition: Pattern.h:365