MLIR  21.0.0git
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"

#include <cassert>
#include <memory>
#include <optional>
#include <unordered_map>
28 
// TableGen classes this header only refers to through pointers/references, so
// forward declarations suffice.
namespace llvm {
class Record;  // A TableGen record (e.g. a `Pattern` def).
class Init;    // Any TableGen value; wrapped by DagLeaf.
class DagInit; // A TableGen `(operator args...)` DAG; wrapped by DagNode.
} // namespace llvm
34 
35 namespace mlir {
36 namespace tblgen {
37 
38 // Mapping from TableGen Record to Operator wrapper object.
39 //
40 // We allocate each wrapper object in heap to make sure the pointer to it is
41 // valid throughout the lifetime of this map. This is important because this map
42 // is shared among multiple patterns to avoid creating the wrapper object for
43 // the same op again and again. But this map will continuously grow.
46 
47 class Pattern;
48 
49 // Wrapper class providing helper methods for accessing TableGen DAG leaves
50 // used inside Patterns. This class is lightweight and designed to be used like
51 // values.
52 //
53 // A TableGen DAG construct is of the syntax
54 // `(operator, arg0, arg1, ...)`.
55 //
56 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
57 // for handy helper methods. It only works on `arg*`s that are not nested DAG
58 // constructs.
59 class DagLeaf {
60 public:
61  explicit DagLeaf(const llvm::Init *def) : def(def) {}
62 
63  // Returns true if this DAG leaf is not specified in the pattern. That is, it
64  // places no further constraints/transforms and just carries over the original
65  // value.
66  bool isUnspecified() const;
67 
68  // Returns true if this DAG leaf is matching an operand. That is, it specifies
69  // a type constraint.
70  bool isOperandMatcher() const;
71 
72  // Returns true if this DAG leaf is matching an attribute. That is, it
73  // specifies an attribute constraint.
74  bool isAttrMatcher() const;
75 
76  // Returns true if this DAG leaf is matching a property. That is, it
77  // specifies a property constraint.
78  bool isPropMatcher() const;
79 
80  // Returns true if this DAG leaf is describing a property. That is, it
81  // is a subclass of `Property` in tablegen.
82  bool isPropDefinition() const;
83 
84  // Returns true if this DAG leaf is wrapping native code call.
85  bool isNativeCodeCall() const;
86 
87  // Returns true if this DAG leaf is specifying a constant attribute.
88  bool isConstantAttr() const;
89 
90  // Returns true if this DAG leaf is specifying a constant property.
91  bool isConstantProp() const;
92 
93  // Returns true if this DAG leaf is specifying an enum case.
94  bool isEnumCase() const;
95 
96  // Returns true if this DAG leaf is specifying a string attribute.
97  bool isStringAttr() const;
98 
99  // Returns this DAG leaf as a constraint. Asserts if fails.
100  Constraint getAsConstraint() const;
101 
102  // Returns this DAG leaf as a property constraint. Asserts if fails. This
103  // allows access to the interface type.
105 
106  // Returns this DAG leaf as a property definition. Asserts if fails.
107  Property getAsProperty() const;
108 
109  // Returns this DAG leaf as an constant attribute. Asserts if fails.
111 
112  // Returns this DAG leaf as an constant property. Asserts if fails.
114 
115  // Returns this DAG leaf as an enum case.
116  // Precondition: isEnumCase()
117  EnumCase getAsEnumCase() const;
118 
119  // Returns the matching condition template inside this DAG leaf. Assumes the
120  // leaf is an operand/attribute matcher and asserts otherwise.
121  std::string getConditionTemplate() const;
122 
123  // Returns the native code call template inside this DAG leaf.
124  // Precondition: isNativeCodeCall()
125  StringRef getNativeCodeTemplate() const;
126 
127  // Returns the number of values will be returned by the native helper
128  // function.
129  // Precondition: isNativeCodeCall()
130  int getNumReturnsOfNativeCode() const;
131 
132  // Returns the string associated with the leaf.
133  // Precondition: isStringAttr()
134  std::string getStringAttr() const;
135 
136  void print(raw_ostream &os) const;
137 
138 private:
140  const void *getAsOpaquePointer() const { return def; }
141 
142  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
143  // also a subclass of the given `superclass`.
144  bool isSubClassOf(StringRef superclass) const;
145 
146  const llvm::Init *def;
147 };
148 
149 // Wrapper class providing helper methods for accessing TableGen DAG constructs
150 // used inside Patterns. This class is lightweight and designed to be used like
151 // values.
152 //
153 // A TableGen DAG construct is of the syntax
154 // `(operator, arg0, arg1, ...)`.
155 //
156 // When used inside Patterns, `operator` corresponds to some dialect op, or
157 // a known list of verbs that defines special transformation actions. This
158 // `arg*` can be a nested DAG construct. This class provides getters to
159 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
160 // methods.
161 //
162 // A null DagNode contains a nullptr and converts to false implicitly.
163 class DagNode {
164 public:
165  explicit DagNode(const llvm::DagInit *node) : node(node) {}
166 
167  // Implicit bool converter that returns true if this DagNode is not a null
168  // DagNode.
169  operator bool() const { return node != nullptr; }
170 
171  // Returns the symbol bound to this DAG node.
172  StringRef getSymbol() const;
173 
174  // Returns the operator wrapper object corresponding to the dialect op matched
175  // by this DAG. The operator wrapper will be queried from the given `mapper`
176  // and created in it if not existing.
177  Operator &getDialectOp(RecordOperatorMap *mapper) const;
178 
179  // Returns the number of operations recursively involved in the DAG tree
180  // rooted from this node.
181  int getNumOps() const;
182 
183  // Returns the number of immediate arguments to this DAG node.
184  int getNumArgs() const;
185 
186  // Returns true if the `index`-th argument is a nested DAG construct.
187  bool isNestedDagArg(unsigned index) const;
188 
189  // Gets the `index`-th argument as a nested DAG construct if possible. Returns
190  // null DagNode otherwise.
191  DagNode getArgAsNestedDag(unsigned index) const;
192 
193  // Gets the `index`-th argument as a DAG leaf.
194  DagLeaf getArgAsLeaf(unsigned index) const;
195 
196  // Returns the specified name of the `index`-th argument.
197  StringRef getArgName(unsigned index) const;
198 
199  // Returns true if this DAG construct means to replace with an existing SSA
200  // value.
201  bool isReplaceWithValue() const;
202 
203  // Returns whether this DAG represents the location of an op creation.
204  bool isLocationDirective() const;
205 
206  // Returns whether this DAG is a return type specifier.
207  bool isReturnTypeDirective() const;
208 
209  // Returns true if this DAG node is wrapping native code call.
210  bool isNativeCodeCall() const;
211 
212  // Returns whether this DAG is an `either` specifier.
213  bool isEither() const;
214 
215  // Returns whether this DAG is an `variadic` specifier.
216  bool isVariadic() const;
217 
218  // Returns true if this DAG node is an operation.
219  bool isOperation() const;
220 
221  // Returns the native code call template inside this DAG node.
222  // Precondition: isNativeCodeCall()
223  StringRef getNativeCodeTemplate() const;
224 
225  // Returns the number of values will be returned by the native helper
226  // function.
227  // Precondition: isNativeCodeCall()
228  int getNumReturnsOfNativeCode() const;
229 
230  void print(raw_ostream &os) const;
231 
232 private:
233  friend class SymbolInfoMap;
235  const void *getAsOpaquePointer() const { return node; }
236 
237  const llvm::DagInit *node; // nullptr means null DagNode
238 };
239 
240 // A class for maintaining information for symbols bound in patterns and
241 // provides methods for resolving them according to specific use cases.
242 //
243 // Symbols can be bound to
244 //
245 // * Op arguments and op results in the source pattern and
246 // * Op results in result patterns.
247 //
248 // Symbols can be referenced in result patterns and additional constraints to
249 // the pattern.
250 //
251 // For example, in
252 //
253 // ```
254 // def : Pattern<
255 // (SrcOp:$results1 $arg0, %arg1),
256 // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
257 // ```
258 //
259 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
260 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
261 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
262 //
263 // If a symbol binds to a multi-result op and it does not have the `__N`
264 // suffix, the symbol is expanded to represent all results generated by the
265 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
266 // only the N-th *static* result as declared in ODS, and that can still
267 // corresponds to multiple *dynamic* values if the N-th *static* result is
268 // variadic.
269 //
270 // This class keeps track of such symbols and resolves them into their bound
271 // values in a suitable way.
273 public:
274  explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {}
275 
276  // Class for information regarding a symbol.
277  class SymbolInfo {
278  public:
279  // Returns a type string of a variable.
280  std::string getVarTypeStr(StringRef name) const;
281 
282  // Returns a string for defining a variable named as `name` to store the
283  // value bound by this symbol.
284  std::string getVarDecl(StringRef name) const;
285 
286  // Returns a string for defining an argument which passes the reference of
287  // the variable.
288  std::string getArgDecl(StringRef name) const;
289 
290  // Returns a variable name for the symbol named as `name`.
291  std::string getVarName(StringRef name) const;
292 
293  private:
294  // Allow SymbolInfoMap to access private methods.
295  friend class SymbolInfoMap;
296 
297  // Structure to uniquely distinguish different locations of the symbols.
298  //
299  // * If a symbol is defined as an operand of an operation, `dag` specifies
300  // the DAG of the operation, `operandIndexOrNumValues` specifies the
301  // operand index, and `variadicSubIndex` must be set to `std::nullopt`.
302  //
303  // * Properties not associated with an operation (e.g. as arguments to
304  // native code) have their corresponding PropConstraint stored in the
305  // `dag` field. This constraint is only used when
306  //
307  // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
308  // of the parent operation, `operandIndexOrNumValues` specifies the
309  // declared operand index of the variadic operand in the parent
310  // operation.
311  //
312  // - If the symbol is defined as a result of `variadic` DAG, the
313  // `variadicSubIndex` must be set to `std::nullopt`, which means that
314  // the symbol binds to the full operand range.
315  //
316  // - If the symbol is defined as a operand, the `variadicSubIndex` must
317  // be set to the index within the variadic sub-operand list.
318  //
319  // * If a symbol is defined in a `either` DAG, `dag` specifies the DAG
320  // of the parent operation, `operandIndexOrNumValues` specifies the
321  // operand index in the parent operation (not necessary the index in the
322  // DAG).
323  //
324  // * If a symbol is defined as a result, specifies the number of returning
325  // value.
326  //
327  // Example 1:
328  //
329  // def : Pat<(OpA $input0, $input1), ...>;
330  //
331  // $input0: (OpA, 0, nullopt)
332  // $input1: (OpA, 1, nullopt)
333  //
334  // Example 2:
335  //
336  // def : Pat<(OpB (variadic:$input0 $input0a, $input0b),
337  // (variadic:$input1 $input1a, $input1b, $input1c)),
338  // ...>;
339  //
340  // $input0: (OpB, 0, nullopt)
341  // $input0a: (OpB, 0, 0)
342  // $input0b: (OpB, 0, 1)
343  // $input1: (OpB, 1, nullopt)
344  // $input1a: (OpB, 1, 0)
345  // $input1b: (OpB, 1, 1)
346  // $input1c: (OpB, 1, 2)
347  //
348  // Example 3:
349  //
350  // def : Pat<(OpC $input0, (either $input1, $input2)), ...>;
351  //
352  // $input0: (OpC, 0, nullopt)
353  // $input1: (OpC, 1, nullopt)
354  // $input2: (OpC, 2, nullopt)
355  //
356  // Example 4:
357  //
358  // def ThreeResultOp : TEST_Op<...> {
359  // let results = (outs
360  // AnyType:$result1,
361  // AnyType:$result2,
362  // AnyType:$result3
363  // );
364  // }
365  //
366  // def : Pat<...,
367  // (ThreeResultOp:$result ...)>;
368  //
369  // $result: (nullptr, 3, nullopt)
370  //
371  struct DagAndConstant {
372  // DagNode and DagLeaf are accessed by value which means it can't be used
373  // as identifier here. Use an opaque pointer type instead.
374  const void *dag;
375  int operandIndexOrNumValues;
376  std::optional<int> variadicSubIndex;
377 
378  DagAndConstant(const void *dag, int operandIndexOrNumValues,
379  std::optional<int> variadicSubIndex)
380  : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
381  variadicSubIndex(variadicSubIndex) {}
382 
383  bool operator==(const DagAndConstant &rhs) const {
384  return dag == rhs.dag &&
385  operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
386  variadicSubIndex == rhs.variadicSubIndex;
387  }
388  };
389 
390  // What kind of entity this symbol represents:
391  // * Attr: op attribute
392  // * Prop: op property
393  // * Operand: op operand
394  // * Result: op result
395  // * Value: a value not attached to an op (e.g., from NativeCodeCall)
396  // * MultipleValues: a pack of values not attached to an op (e.g., from
397  // NativeCodeCall). This kind supports indexing.
398  enum class Kind : uint8_t {
399  Attr,
400  Prop,
401  Operand,
402  Result,
403  Value,
404  MultipleValues
405  };
406 
407  // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
408  // and `Operand` so should be std::nullopt for `Result` and `Value` kind.
409  SymbolInfo(const Operator *op, Kind kind,
410  std::optional<DagAndConstant> dagAndConstant);
411 
412  // Static methods for creating SymbolInfo.
413  static SymbolInfo getAttr(const Operator *op, int index) {
414  return SymbolInfo(op, Kind::Attr,
415  DagAndConstant(nullptr, index, std::nullopt));
416  }
417  static SymbolInfo getAttr() {
418  return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
419  }
420  static SymbolInfo getProp(const Operator *op, int index) {
421  return SymbolInfo(op, Kind::Prop,
422  DagAndConstant(nullptr, index, std::nullopt));
423  }
424  static SymbolInfo getProp(const PropConstraint *constraint) {
425  // -1 for anthe `operandIndexOrNumValues` is a sentinel value.
426  return SymbolInfo(nullptr, Kind::Prop,
427  DagAndConstant(constraint, -1, std::nullopt));
428  }
429  static SymbolInfo
430  getOperand(DagNode node, const Operator *op, int operandIndex,
431  std::optional<int> variadicSubIndex = std::nullopt) {
432  return SymbolInfo(op, Kind::Operand,
433  DagAndConstant(node.getAsOpaquePointer(), operandIndex,
434  variadicSubIndex));
435  }
436  static SymbolInfo getResult(const Operator *op) {
437  return SymbolInfo(op, Kind::Result, std::nullopt);
438  }
439  static SymbolInfo getValue() {
440  return SymbolInfo(nullptr, Kind::Value, std::nullopt);
441  }
442  static SymbolInfo getMultipleValues(int numValues) {
443  return SymbolInfo(nullptr, Kind::MultipleValues,
444  DagAndConstant(nullptr, numValues, std::nullopt));
445  }
446 
447  // Returns the number of static values this symbol corresponds to.
448  // A static value is an operand/result declared in ODS. Normally a symbol
449  // only represents one static value, but symbols bound to op results can
450  // represent more than one if the op is a multi-result op.
451  int getStaticValueCount() const;
452 
453  // Returns a string containing the C++ expression for referencing this
454  // symbol as a value (if this symbol represents one static value) or a value
455  // range (if this symbol represents multiple static values). `name` is the
456  // name of the C++ variable that this symbol bounds to. `index` should only
457  // be used for indexing results. `fmt` is used to format each value.
458  // `separator` is used to separate values if this is a value range.
459  std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
460  const char *separator) const;
461 
462  // Returns a string containing the C++ expression for referencing this
463  // symbol as a value range regardless of how many static values this symbol
464  // represents. `name` is the name of the C++ variable that this symbol
465  // bounds to. `index` should only be used for indexing results. `fmt` is
466  // used to format each value. `separator` is used to separate values in the
467  // range.
468  std::string getAllRangeUse(StringRef name, int index, const char *fmt,
469  const char *separator) const;
470 
471  // The argument index (for `Attr` and `Operand` only)
472  int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
473 
474  // The number of values in the MultipleValue
475  int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
476 
477  // The variadic sub-operands index (for variadic `Operand` only)
478  std::optional<int> getVariadicSubIndex() const {
479  return dagAndConstant->variadicSubIndex;
480  }
481 
482  const Operator *op; // The op where the bound entity belongs
483  Kind kind; // The kind of the bound entity
484 
485  // The tuple of DagNode pointer and two constant values (for `Attr`,
486  // `Operand` and the size of MultipleValue symbol). Note that operands may
487  // be bound to the same symbol, use the DagNode and index to distinguish
488  // them. For `Attr` and MultipleValue, the Dag part will be nullptr.
489  std::optional<DagAndConstant> dagAndConstant;
490 
491  // Alternative name for the symbol. It is used in case the name
492  // is not unique. Applicable for `Operand` only.
493  std::optional<std::string> alternativeName;
494  };
495 
496  using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
497 
498  // Iterators for accessing all symbols.
499  using iterator = BaseT::iterator;
500  iterator begin() { return symbolInfoMap.begin(); }
501  iterator end() { return symbolInfoMap.end(); }
502 
503  // Const iterators for accessing all symbols.
504  using const_iterator = BaseT::const_iterator;
505  const_iterator begin() const { return symbolInfoMap.begin(); }
506  const_iterator end() const { return symbolInfoMap.end(); }
507 
508  // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
509  // Returns false if `symbol` is already bound and symbols are not operands.
510  bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
511  int argIndex,
512  std::optional<int> variadicSubIndex = std::nullopt);
513 
514  // Binds the given `symbol` to the results the given `op`. Returns false if
515  // `symbol` is already bound.
516  bool bindOpResult(StringRef symbol, const Operator &op);
517 
518  // A helper function for dispatching target value binding functions.
519  bool bindValues(StringRef symbol, int numValues = 1);
520 
521  // Registers the given `symbol` as bound to the Value(s). Returns false if
522  // `symbol` is already bound.
523  bool bindValue(StringRef symbol);
524 
525  // Registers the given `symbol` as bound to a MultipleValue. Return false if
526  // `symbol` is already bound.
527  bool bindMultipleValues(StringRef symbol, int numValues);
528 
529  // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
530  // is already bound.
531  bool bindAttr(StringRef symbol);
532 
533  // Registers the given `symbol` as bound to a property that satisfies the
534  // given `constraint`. `constraint` must name a concrete interface type.
535  bool bindProp(StringRef symbol, const PropConstraint &constraint);
536 
537  // Returns true if the given `symbol` is bound.
538  bool contains(StringRef symbol) const;
539 
540  // Returns an iterator to the information of the given symbol named as `key`.
541  const_iterator find(StringRef key) const;
542 
543  // Returns an iterator to the information of the given symbol named as `key`,
544  // with index `argIndex` for operator `op`.
545  const_iterator findBoundSymbol(StringRef key, DagNode node,
546  const Operator &op, int argIndex,
547  std::optional<int> variadicSubIndex) const;
548  const_iterator findBoundSymbol(StringRef key,
549  const SymbolInfo &symbolInfo) const;
550 
551  // Returns the bounds of a range that includes all the elements which
552  // bind to the `key`.
553  std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
554 
555  // Returns number of times symbol named as `key` was used.
556  int count(StringRef key) const;
557 
558  // Returns the number of static values of the given `symbol` corresponds to.
559  // A static value is an operand/result declared in ODS. Normally a symbol only
560  // represents one static value, but symbols bound to op results can represent
561  // more than one if the op is a multi-result op.
562  int getStaticValueCount(StringRef symbol) const;
563 
564  // Returns a string containing the C++ expression for referencing this
565  // symbol as a value (if this symbol represents one static value) or a value
566  // range (if this symbol represents multiple static values). `fmt` is used to
567  // format each value. `separator` is used to separate values if `symbol`
568  // represents a value range.
569  std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
570  const char *separator = ", ") const;
571 
572  // Returns a string containing the C++ expression for referencing this
573  // symbol as a value range regardless of how many static values this symbol
574  // represents. `fmt` is used to format each value. `separator` is used to
575  // separate values in the range.
576  std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
577  const char *separator = ", ") const;
578 
579  // Assign alternative unique names to Operands that have equal names.
581 
582  // Splits the given `symbol` into a value pack name and an index. Returns the
583  // value pack name and writes the index to `index` on success. Returns
584  // `symbol` itself if it does not contain an index.
585  //
586  // We can use `name__N` to access the `N`-th value in the value pack bound to
587  // `name`. `name` is typically the results of an multi-result op.
588  static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
589 
590 private:
591  BaseT symbolInfoMap;
592 
593  // Pattern instantiation location. This is intended to be used as parameter
594  // to PrintFatalError() to report errors.
595  ArrayRef<SMLoc> loc;
596 };
597 
598 // Wrapper class providing helper methods for accessing MLIR Pattern defined
599 // in TableGen. This class should closely reflect what is defined as class
600 // `Pattern` in TableGen. This class contains maps so it is not intended to be
601 // used as values.
602 class Pattern {
603 public:
604  explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
605 
606  // Returns the source pattern to match.
607  DagNode getSourcePattern() const;
608 
609  // Returns the number of result patterns generated by applying this rewrite
610  // rule.
611  int getNumResultPatterns() const;
612 
613  // Returns the DAG tree root node of the `index`-th result pattern.
614  DagNode getResultPattern(unsigned index) const;
615 
616  // Collects all symbols bound in the source pattern into `infoMap`.
618 
619  // Collects all symbols bound in result patterns into `infoMap`.
621 
622  // Returns the op that the root node of the source pattern matches.
623  const Operator &getSourceRootOp();
624 
625  // Returns the operator wrapper object corresponding to the given `node`'s DAG
626  // operator.
628 
629  // Returns the constraints.
630  std::vector<AppliedConstraint> getConstraints() const;
631 
632  // Returns the number of supplemental auxiliary patterns generated by applying
633  // this rewrite rule.
634  int getNumSupplementalPatterns() const;
635 
636  // Returns the DAG tree root node of the `index`-th supplemental result
637  // pattern.
638  DagNode getSupplementalPattern(unsigned index) const;
639 
640  // Returns the benefit score of the pattern.
641  int getBenefit() const;
642 
643  using IdentifierLine = std::pair<StringRef, unsigned>;
644 
645  // Returns the file location of the pattern (buffer identifier + line number
646  // pair).
647  std::vector<IdentifierLine> getLocation() const;
648 
649  // Recursively collects all bound symbols inside the DAG tree rooted
650  // at `tree` and updates the given `infoMap`.
651  void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
652  bool isSrcPattern);
653 
654 private:
655  // Helper function to verify variable binding.
656  void verifyBind(bool result, StringRef symbolName);
657 
658  // The TableGen definition of this pattern.
659  const llvm::Record &def;
660 
661  // All operators.
662  // TODO: we need a proper context manager, like MLIRContext, for managing the
663  // lifetime of shared entities.
664  RecordOperatorMap *recordOpMap;
665 };
666 
667 } // namespace tblgen
668 } // namespace mlir
669 
670 namespace llvm {
671 template <>
672 struct DenseMapInfo<mlir::tblgen::DagNode> {
674  return mlir::tblgen::DagNode(
676  }
678  return mlir::tblgen::DagNode(
680  }
681  static unsigned getHashValue(mlir::tblgen::DagNode node) {
682  return llvm::hash_value(node.getAsOpaquePointer());
683  }
685  return lhs.node == rhs.node;
686  }
687 };
688 
689 template <>
690 struct DenseMapInfo<mlir::tblgen::DagLeaf> {
692  return mlir::tblgen::DagLeaf(
694  }
696  return mlir::tblgen::DagLeaf(
698  }
699  static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
700  return llvm::hash_value(leaf.getAsOpaquePointer());
701  }
703  return lhs.def == rhs.def;
704  }
705 };
706 } // namespace llvm
707 
708 #endif // MLIR_TABLEGEN_PATTERN_H_
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Constraint getAsConstraint() const
Definition: Pattern.cpp:76
bool isNativeCodeCall() const
Definition: Pattern.cpp:64
bool isPropMatcher() const
Definition: Pattern.cpp:54
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:116
ConstantAttr getAsConstantAttr() const
Definition: Pattern.cpp:92
void print(raw_ostream &os) const
Definition: Pattern.cpp:131
std::string getStringAttr() const
Definition: Pattern.cpp:121
Property getAsProperty() const
Definition: Pattern.cpp:87
bool isEnumCase() const
Definition: Pattern.cpp:70
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:111
DagLeaf(const llvm::Init *def)
Definition: Pattern.h:61
std::string getConditionTemplate() const
Definition: Pattern.cpp:107
bool isConstantProp() const
Definition: Pattern.cpp:72
PropConstraint getAsPropConstraint() const
Definition: Pattern.cpp:82
ConstantProp getAsConstantProp() const
Definition: Pattern.cpp:102
bool isUnspecified() const
Definition: Pattern.cpp:40
EnumCase getAsEnumCase() const
Definition: Pattern.cpp:97
bool isAttrMatcher() const
Definition: Pattern.cpp:49
bool isOperandMatcher() const
Definition: Pattern.cpp:44
bool isPropDefinition() const
Definition: Pattern.cpp:59
bool isConstantAttr() const
Definition: Pattern.cpp:68
bool isStringAttr() const
Definition: Pattern.cpp:74
bool isReturnTypeDirective() const
Definition: Pattern.cpp:216
bool isLocationDirective() const
Definition: Pattern.cpp:211
bool isReplaceWithValue() const
Definition: Pattern.cpp:206
DagNode getArgAsNestedDag(unsigned index) const
Definition: Pattern.cpp:193
bool isOperation() const
Definition: Pattern.cpp:146
DagLeaf getArgAsLeaf(unsigned index) const
Definition: Pattern.cpp:197
int getNumReturnsOfNativeCode() const
Definition: Pattern.cpp:159
StringRef getNativeCodeTemplate() const
Definition: Pattern.cpp:152
void print(raw_ostream &os) const
Definition: Pattern.cpp:231
int getNumOps() const
Definition: Pattern.cpp:176
Operator & getDialectOp(RecordOperatorMap *mapper) const
Definition: Pattern.cpp:168
bool isVariadic() const
Definition: Pattern.cpp:226
bool isNativeCodeCall() const
Definition: Pattern.cpp:140
bool isEither() const
Definition: Pattern.cpp:221
bool isNestedDagArg(unsigned index) const
Definition: Pattern.cpp:189
StringRef getSymbol() const
Definition: Pattern.cpp:166
int getNumArgs() const
Definition: Pattern.cpp:187
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:165
StringRef getArgName(unsigned index) const
Definition: Pattern.cpp:202
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
int getBenefit() const
int getNumResultPatterns() const
Definition: Pattern.cpp:687
std::vector< IdentifierLine > getLocation() const
Definition: Pattern.cpp:774
DagNode getSourcePattern() const
Definition: Pattern.cpp:683
const Operator & getSourceRootOp()
Definition: Pattern.cpp:716
std::vector< AppliedConstraint > getConstraints() const
Definition: Pattern.cpp:724
DagNode getResultPattern(unsigned index) const
Definition: Pattern.cpp:692
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern)
Definition: Pattern.cpp:794
Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
Operator & getDialectOp(DagNode node)
Definition: Pattern.cpp:720
DagNode getSupplementalPattern(unsigned index) const
Definition: Pattern.cpp:757
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:697
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap)
Definition: Pattern.cpp:707
int getNumSupplementalPatterns() const
Definition: Pattern.cpp:752
std::pair< StringRef, unsigned > IdentifierLine
Definition: Pattern.h:643
std::string getArgDecl(StringRef name) const
Definition: Pattern.cpp:327
std::string getVarName(StringRef name) const
Definition: Pattern.cpp:274
std::string getVarTypeStr(StringRef name) const
Definition: Pattern.cpp:278
std::string getVarDecl(StringRef name) const
Definition: Pattern.cpp:320
BaseT::iterator iterator
Definition: Pattern.h:499
static StringRef getValuePackName(StringRef symbol, int *index=nullptr)
Definition: Pattern.cpp:240
const_iterator begin() const
Definition: Pattern.h:505
int count(StringRef key) const
Definition: Pattern.cpp:600
const_iterator find(StringRef key) const
Definition: Pattern.cpp:566
bool bindMultipleValues(StringRef symbol, int numValues)
Definition: Pattern.cpp:543
bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, int argIndex, std::optional< int > variadicSubIndex=std::nullopt)
Definition: Pattern.cpp:489
std::string getAllRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:631
bool bindValues(StringRef symbol, int numValues=1)
Definition: Pattern.cpp:531
bool bindAttr(StringRef symbol)
Definition: Pattern.cpp:550
bool bindProp(StringRef symbol, const PropConstraint &constraint)
Definition: Pattern.cpp:555
SymbolInfoMap(ArrayRef< SMLoc > loc)
Definition: Pattern.h:274
bool bindValue(StringRef symbol)
Definition: Pattern.cpp:538
const_iterator findBoundSymbol(StringRef key, DagNode node, const Operator &op, int argIndex, std::optional< int > variadicSubIndex) const
Definition: Pattern.cpp:573
std::pair< iterator, iterator > getRangeOfEqualElements(StringRef key)
Definition: Pattern.cpp:594
int getStaticValueCount(StringRef symbol) const
Definition: Pattern.cpp:605
bool contains(StringRef symbol) const
Definition: Pattern.cpp:562
BaseT::const_iterator const_iterator
Definition: Pattern.h:504
bool bindOpResult(StringRef symbol, const Operator &op)
Definition: Pattern.cpp:524
std::string getValueAndRangeUse(StringRef symbol, const char *fmt="{0}", const char *separator=", ") const
Definition: Pattern.cpp:616
const_iterator end() const
Definition: Pattern.h:506
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo)
Include the generated interface declarations.
bool operator==(StringAttr lhs, std::nullptr_t)
Define comparisons for StringAttr against nullptr and itself to avoid the StringRef overloads from be...
static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs)
Definition: Pattern.h:702
static mlir::tblgen::DagLeaf getEmptyKey()
Definition: Pattern.h:691
static mlir::tblgen::DagLeaf getTombstoneKey()
Definition: Pattern.h:695
static unsigned getHashValue(mlir::tblgen::DagLeaf leaf)
Definition: Pattern.h:699
static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs)
Definition: Pattern.h:684
static mlir::tblgen::DagNode getEmptyKey()
Definition: Pattern.h:673
static mlir::tblgen::DagNode getTombstoneKey()
Definition: Pattern.h:677
static unsigned getHashValue(mlir::tblgen::DagNode node)
Definition: Pattern.h:681