MLIR  21.0.0git
Predicate.h
Go to the documentation of this file.
1 //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for "predicates" used when converting PDL into
10 // a matcher tree. Predicates are composed of three different parts:
11 //
12 // * Positions
13 // - A position refers to a specific location on the input DAG, i.e. an
14 // existing MLIR entity being matched. These can be attributes, operands,
15 // operations, results, and types. Each position also defines a relation to
16 // its parent. For example, the operand `[0] -> 1` has a parent operation
17 // position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18 // position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19 // `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20 // without a parent is `[0]`, which refers to the root operation.
21 // * Questions
22 // - A question refers to a query on a specific positional value. For
23 // example, an operation name question checks the name of an operation
24 // position.
25 // * Answers
26 // - An answer is the expected result of a question. For example, when
27 // matching an operation with the name "foo.op", the question would be an
28 // operation name question, with an expected answer of "foo.op".
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
34 
35 #include "mlir/IR/MLIRContext.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/Types.h"
39 
40 namespace mlir {
41 namespace pdl_to_pdl_interp {
42 namespace Predicates {
43 /// An enumeration of the kinds of predicates.
44 enum Kind : unsigned {
45  /// Positions, ordered by decreasing priority.
58 
59  // Questions, ordered by dependency and decreasing priority.
70 
71  // Answers.
78 };
79 } // namespace Predicates
80 
81 /// Base class for all predicates, used to allow efficient pointer comparison.
82 template <typename ConcreteT, typename BaseT, typename Key,
84 class PredicateBase : public BaseT {
85 public:
86  using KeyTy = Key;
88 
89  template <typename KeyT>
90  explicit PredicateBase(KeyT &&key)
91  : BaseT(Kind), key(std::forward<KeyT>(key)) {}
92 
93  /// Get the uniqued instance of this predicate for the given key arguments.
94  template <typename... Args>
95  static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
96  return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
97  }
98 
99  /// Construct an instance with the given storage allocator.
100  template <typename KeyT>
101  static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
102  KeyT &&key) {
103  return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
104  }
105 
106  /// Key equality used by the storage uniquer, and kind-based `classof`.
107  bool operator==(const KeyTy &key) const { return this->key == key; }
108  static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
109 
110  /// Return the key value of this predicate.
111  const KeyTy &getValue() const { return key; }
112 
113 protected:
115 };
116 
117 /// Base storage for simple predicates that only unique with the kind.
118 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
119 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
120 public:
122 
123  explicit PredicateBase() : BaseT(Kind) {}
124 
125  static ConcreteT *get(StorageUniquer &uniquer) {
126  return uniquer.get<ConcreteT>();
127  }
128  static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
129 };
130 
131 //===----------------------------------------------------------------------===//
132 // Positions
133 //===----------------------------------------------------------------------===//
134 
135 struct OperationPosition;
136 
137 /// A position describes a value on the input IR on which a predicate may be
138 /// applied, such as an operation or attribute. This enables re-use between
139 /// predicates, and aids bytecode generation and memory management.
140 ///
141 /// Operation positions form the base of other positions, which are formed
142 /// relative to a parent operation. Operations are anchored at Operand nodes,
143 /// except for the root operation which is parentless.
145 public:
146  explicit Position(Predicates::Kind kind) : kind(kind) {}
147  virtual ~Position();
148 
149  /// Returns the depth of the first ancestor operation position.
150  unsigned getOperationDepth() const;
151 
152  /// Returns the parent position. The root operation position has no parent.
153  Position *getParent() const { return parent; }
154 
155  /// Returns the kind of this position.
156  Predicates::Kind getKind() const { return kind; }
157 
158 protected:
159  /// Link to the parent position.
160  Position *parent = nullptr;
161 
162 private:
163  /// The kind of this position.
164  Predicates::Kind kind;
165 };
166 
167 //===----------------------------------------------------------------------===//
168 // AttributePosition
169 //===----------------------------------------------------------------------===//
170 
171 /// A position describing an attribute of an operation.
173  : public PredicateBase<AttributePosition, Position,
174  std::pair<OperationPosition *, StringAttr>,
175  Predicates::AttributePos> {
176  explicit AttributePosition(const KeyTy &key);
177 
178  /// Returns the attribute name of this position.
179  StringAttr getName() const { return key.second; }
180 };
181 
182 //===----------------------------------------------------------------------===//
183 // AttributeLiteralPosition
184 //===----------------------------------------------------------------------===//
185 
186 /// A position describing a literal attribute.
188  : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
189  Predicates::AttributeLiteralPos> {
191 };
192 
193 //===----------------------------------------------------------------------===//
194 // ForEachPosition
195 //===----------------------------------------------------------------------===//
196 
197 /// A position describing an iterative choice of an operation.
198 struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
199  std::pair<Position *, unsigned>,
200  Predicates::ForEachPos> {
201  explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
202 
203  /// Returns the ID, for differentiating various loops.
204  /// For upward traversals, this is the index of the root.
205  unsigned getID() const { return key.second; }
206 };
207 
208 //===----------------------------------------------------------------------===//
209 // OperandPosition
210 //===----------------------------------------------------------------------===//
211 
212 /// A position describing an operand of an operation.
214  : public PredicateBase<OperandPosition, Position,
215  std::pair<OperationPosition *, unsigned>,
216  Predicates::OperandPos> {
217  explicit OperandPosition(const KeyTy &key);
218 
219  /// Returns the operand number of this position.
220  unsigned getOperandNumber() const { return key.second; }
221 };
222 
223 //===----------------------------------------------------------------------===//
224 // OperandGroupPosition
225 //===----------------------------------------------------------------------===//
226 
227 /// A position describing an operand group of an operation.
229  : public PredicateBase<
230  OperandGroupPosition, Position,
231  std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
232  Predicates::OperandGroupPos> {
233  explicit OperandGroupPosition(const KeyTy &key);
234 
235  /// Returns a hash suitable for the given keytype.
236  static llvm::hash_code hashKey(const KeyTy &key) {
237  return llvm::hash_value(key);
238  }
239 
240  /// Returns the group number of this position. If std::nullopt, this group
241  /// refers to all operands.
242  std::optional<unsigned> getOperandGroupNumber() const {
243  return std::get<1>(key);
244  }
245 
246  /// Returns true if the operand group has unknown size. If false, the operand
247  /// group has at most one element.
248  bool isVariadic() const { return std::get<2>(key); }
249 };
250 
251 //===----------------------------------------------------------------------===//
252 // OperationPosition
253 //===----------------------------------------------------------------------===//
254 
255 /// An operation position describes an operation node in the IR. Other position
256 /// kinds are formed with respect to an operation position.
257 struct OperationPosition : public PredicateBase<OperationPosition, Position,
258  std::pair<Position *, unsigned>,
259  Predicates::OperationPos> {
260  explicit OperationPosition(const KeyTy &key) : Base(key) {
261  parent = key.first;
262  }
263 
264  /// Returns a hash suitable for the given keytype.
265  static llvm::hash_code hashKey(const KeyTy &key) {
266  return llvm::hash_value(key);
267  }
268 
269  /// Gets the root position.
271  return Base::get(uniquer, nullptr, 0);
272  }
273 
274  /// Gets an operation position with the given parent.
276  return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
277  }
278 
279  /// Returns the depth of this position.
280  unsigned getDepth() const { return key.second; }
281 
282  /// Returns if this operation position corresponds to the root.
283  bool isRoot() const { return getDepth() == 0; }
284 
285  /// Returns if this operation represents an operand defining op.
286  bool isOperandDefiningOp() const;
287 };
288 
289 //===----------------------------------------------------------------------===//
290 // ConstraintPosition
291 //===----------------------------------------------------------------------===//
292 
293 struct ConstraintQuestion;
294 
295 /// A position describing the result of a native constraint. It saves the
296 /// corresponding ConstraintQuestion and result index to enable referring
297 /// back to them.
299  : public PredicateBase<ConstraintPosition, Position,
300  std::pair<ConstraintQuestion *, unsigned>,
301  Predicates::ConstraintResultPos> {
303 
304  /// Returns the ConstraintQuestion to enable keeping track of the native
305  /// constraint this position stems from.
306  ConstraintQuestion *getQuestion() const { return key.first; }
307 
308  /// Returns the result index of this position.
309  unsigned getIndex() const { return key.second; }
310 };
311 
312 //===----------------------------------------------------------------------===//
313 // ResultPosition
314 //===----------------------------------------------------------------------===//
315 
316 /// A position describing a result of an operation.
318  : public PredicateBase<ResultPosition, Position,
319  std::pair<OperationPosition *, unsigned>,
320  Predicates::ResultPos> {
321  explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
322 
323  /// Returns the result number of this position.
324  unsigned getResultNumber() const { return key.second; }
325 };
326 
327 //===----------------------------------------------------------------------===//
328 // ResultGroupPosition
329 //===----------------------------------------------------------------------===//
330 
331 /// A position describing a result group of an operation.
333  : public PredicateBase<
334  ResultGroupPosition, Position,
335  std::tuple<OperationPosition *, std::optional<unsigned>, bool>,
336  Predicates::ResultGroupPos> {
337  explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
338  parent = std::get<0>(key);
339  }
340 
341  /// Returns a hash suitable for the given keytype.
342  static llvm::hash_code hashKey(const KeyTy &key) {
343  return llvm::hash_value(key);
344  }
345 
346  /// Returns the group number of this position. If std::nullopt, this group
347  /// refers to all results.
348  std::optional<unsigned> getResultGroupNumber() const {
349  return std::get<1>(key);
350  }
351 
352  /// Returns true if the result group has unknown size. If false, the result
353  /// group has at most one element.
354  bool isVariadic() const { return std::get<2>(key); }
355 };
356 
357 //===----------------------------------------------------------------------===//
358 // TypePosition
359 //===----------------------------------------------------------------------===//
360 
361 /// A position describing the result type of an entity, i.e. an Attribute,
362 /// Operand, Result, etc.
363 struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
364  Predicates::TypePos> {
365  explicit TypePosition(const KeyTy &key) : Base(key) {
368  "expected parent to be an attribute, operand, or result");
369  parent = key;
370  }
371 };
372 
373 //===----------------------------------------------------------------------===//
374 // TypeLiteralPosition
375 //===----------------------------------------------------------------------===//
376 
377 /// A position describing a literal type or type range. The value is stored as
378 /// either a TypeAttr, or an ArrayAttr of TypeAttr.
380  : public PredicateBase<TypeLiteralPosition, Position, Attribute,
381  Predicates::TypeLiteralPos> {
383 };
384 
385 //===----------------------------------------------------------------------===//
386 // UsersPosition
387 //===----------------------------------------------------------------------===//
388 
389 /// A position describing the users of a value or a range of values. The second
390 /// value in the key indicates whether we choose users of a representative for
391 /// a range (this is true, e.g., in the upward traversals).
393  : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
394  Predicates::UsersPos> {
395  explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
396 
397  /// Returns a hash suitable for the given keytype.
398  static llvm::hash_code hashKey(const KeyTy &key) {
399  return llvm::hash_value(key);
400  }
401 
402  /// Indicates whether to compute a range of a representative.
403  bool useRepresentative() const { return key.second; }
404 };
405 
406 //===----------------------------------------------------------------------===//
407 // Qualifiers
408 //===----------------------------------------------------------------------===//
409 
410 /// An ordinal predicate consists of a "Question" and a set of acceptable
411 /// "Answers" (later converted to ordinal values). A predicate will query some
412 /// property of a positional value and decide what to do based on the result.
413 ///
414 /// This makes top-level predicate representations ordinal (SwitchOp). Later,
415 /// predicates that end up with only one acceptable answer (including all
416 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
417 /// matcher.
418 ///
419 /// For simplicity, both are represented as "qualifiers", with a base kind and
420 /// perhaps additional properties. For example, all OperationName predicates ask
421 /// the same question, but ConstraintQuestion predicates may ask different ones.
423 public:
424  explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
425 
426  /// Returns the kind of this qualifier.
427  Predicates::Kind getKind() const { return kind; }
428 
429 private:
430  /// The kind of this qualifier.
431  Predicates::Kind kind;
432 };
433 
434 //===----------------------------------------------------------------------===//
435 // Answers
436 //===----------------------------------------------------------------------===//
437 
438 /// An Answer representing an `Attribute` value.
440  : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
441  Predicates::AttributeAnswer> {
442  using Base::Base;
443 };
444 
445 /// An Answer representing an `OperationName` value.
447  : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
448  Predicates::OperationNameAnswer> {
449  using Base::Base;
450 };
451 
452 /// An Answer representing a boolean `true` value.
454  : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
455  using Base::Base;
456 };
457 
458 /// An Answer representing a boolean 'false' value.
460  : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
461  using Base::Base;
462 };
463 
464 /// An Answer representing a `Type` value. The value is stored as either a
465 /// TypeAttr, or an ArrayAttr of TypeAttr.
466 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
467  Predicates::TypeAnswer> {
468  using Base::Base;
469 };
470 
471 /// An Answer representing an unsigned value.
473  : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
474  Predicates::UnsignedAnswer> {
475  using Base::Base;
476 };
477 
478 //===----------------------------------------------------------------------===//
479 // Questions
480 //===----------------------------------------------------------------------===//
481 
482 /// Compare an `Attribute` to a constant value.
484  : public PredicateBase<AttributeQuestion, Qualifier, void,
485  Predicates::AttributeQuestion> {};
486 
487 /// Apply a parameterized constraint to multiple position values and possibly
488 /// produce results.
490  : public PredicateBase<
491  ConstraintQuestion, Qualifier,
492  std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
493  Predicates::ConstraintQuestion> {
494  using Base::Base;
495 
496  /// Return the name of the constraint.
497  StringRef getName() const { return std::get<0>(key); }
498 
499  /// Return the arguments of the constraint.
500  ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
501 
502  /// Return the result types of the constraint.
503  ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
504 
505  /// Return the negation status of the constraint.
506  bool getIsNegated() const { return std::get<3>(key); }
507 
508  /// Construct an instance with the given storage allocator.
510  KeyTy key) {
511  return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
512  alloc.copyInto(std::get<1>(key)),
513  alloc.copyInto(std::get<2>(key)),
514  std::get<3>(key)});
515  }
516 
517  /// Returns a hash suitable for the given keytype.
518  static llvm::hash_code hashKey(const KeyTy &key) {
519  return llvm::hash_value(key);
520  }
521 };
522 
523 /// Compare the equality of two values.
525  : public PredicateBase<EqualToQuestion, Qualifier, Position *,
526  Predicates::EqualToQuestion> {
527  using Base::Base;
528 };
529 
530 /// Compare a positional value with null, i.e. check if it exists.
532  : public PredicateBase<IsNotNullQuestion, Qualifier, void,
533  Predicates::IsNotNullQuestion> {};
534 
535 /// Compare the number of operands of an operation with a known value.
537  : public PredicateBase<OperandCountQuestion, Qualifier, void,
538  Predicates::OperandCountQuestion> {};
540  : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
541  Predicates::OperandCountAtLeastQuestion> {};
542 
543 /// Compare the name of an operation with a known value.
545  : public PredicateBase<OperationNameQuestion, Qualifier, void,
546  Predicates::OperationNameQuestion> {};
547 
548 /// Compare the number of results of an operation with a known value.
550  : public PredicateBase<ResultCountQuestion, Qualifier, void,
551  Predicates::ResultCountQuestion> {};
553  : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
554  Predicates::ResultCountAtLeastQuestion> {};
555 
556 /// Compare the type of an attribute or value with a known type.
557 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
558  Predicates::TypeQuestion> {};
559 
560 //===----------------------------------------------------------------------===//
561 // PredicateUniquer
562 //===----------------------------------------------------------------------===//
563 
564 /// This class provides a storage uniquer that is used to allocate predicate
565 /// instances; every Position, Answer and Question kind is registered below.
567 public:
569  // Register the types of Positions with the uniquer.
570  registerParametricStorageType<AttributePosition>();
571  registerParametricStorageType<AttributeLiteralPosition>();
572  registerParametricStorageType<ConstraintPosition>();
573  registerParametricStorageType<ForEachPosition>();
574  registerParametricStorageType<OperandPosition>();
575  registerParametricStorageType<OperandGroupPosition>();
576  registerParametricStorageType<OperationPosition>();
577  registerParametricStorageType<ResultPosition>();
578  registerParametricStorageType<ResultGroupPosition>();
579  registerParametricStorageType<TypePosition>();
580  registerParametricStorageType<TypeLiteralPosition>();
581  registerParametricStorageType<UsersPosition>();
582 
583  // Register the types of Answers with the uniquer.
584  registerParametricStorageType<AttributeAnswer>();
585  registerParametricStorageType<OperationNameAnswer>();
586  registerParametricStorageType<TypeAnswer>();
587  registerParametricStorageType<UnsignedAnswer>();
588  registerSingletonStorageType<FalseAnswer>();
589  registerSingletonStorageType<TrueAnswer>();
590 
591  // Register the types of Questions with the uniquer.
592  registerParametricStorageType<ConstraintQuestion>();
593  registerParametricStorageType<EqualToQuestion>();
594  registerSingletonStorageType<AttributeQuestion>();
595  registerSingletonStorageType<IsNotNullQuestion>();
596  registerSingletonStorageType<OperandCountQuestion>();
597  registerSingletonStorageType<OperandCountAtLeastQuestion>();
598  registerSingletonStorageType<OperationNameQuestion>();
599  registerSingletonStorageType<ResultCountQuestion>();
600  registerSingletonStorageType<ResultCountAtLeastQuestion>();
601  registerSingletonStorageType<TypeQuestion>();
602  }
603 };
604 
605 //===----------------------------------------------------------------------===//
606 // PredicateBuilder
607 //===----------------------------------------------------------------------===//
608 
609 /// This class provides utilities for constructing predicates.
611 public:
613  : uniquer(uniquer), ctx(ctx) {}
614 
615  //===--------------------------------------------------------------------===//
616  // Positions
617  //===--------------------------------------------------------------------===//
618 
619  /// Returns the root operation position.
621 
622  /// Returns the parent position defining the value held by the given operand.
624  assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
625  "expected operand position");
626  return OperationPosition::get(uniquer, p);
627  }
628 
629  /// Returns the operation position equivalent to the given ForEachPosition.
631  assert((isa<ForEachPosition>(p)) && "expected users position");
632  return OperationPosition::get(uniquer, p);
633  }
634 
635  /// Returns a position for the value produced as result `index` of `q`.
637  unsigned index) {
638  return ConstraintPosition::get(uniquer, std::make_pair(q, index));
639  }
640 
641  /// Returns an attribute position for an attribute of the given operation.
642  Position *getAttribute(OperationPosition *p, StringRef name) {
643  return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
644  }
645 
646  /// Returns an attribute position for the given attribute.
648  return AttributeLiteralPosition::get(uniquer, attr);
649  }
650 
651  Position *getForEach(Position *p, unsigned id) {
652  return ForEachPosition::get(uniquer, p, id);
653  }
654 
655  /// Returns an operand position for an operand of the given operation.
656  Position *getOperand(OperationPosition *p, unsigned operand) {
657  return OperandPosition::get(uniquer, p, operand);
658  }
659 
660  /// Returns a position for a group of operands of the given operation.
661  Position *getOperandGroup(OperationPosition *p, std::optional<unsigned> group,
662  bool isVariadic) {
663  return OperandGroupPosition::get(uniquer, p, group, isVariadic);
664  }
666  return getOperandGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
667  }
668 
669  /// Returns a result position for a result of the given operation.
670  Position *getResult(OperationPosition *p, unsigned result) {
671  return ResultPosition::get(uniquer, p, result);
672  }
673 
674  /// Returns a position for a group of results of the given operation.
675  Position *getResultGroup(OperationPosition *p, std::optional<unsigned> group,
676  bool isVariadic) {
677  return ResultGroupPosition::get(uniquer, p, group, isVariadic);
678  }
680  return getResultGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true);
681  }
682 
683  /// Returns a type position for the given entity.
684  Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
685 
686  /// Returns a type position for the given type value. The value is stored
687  /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
689  return TypeLiteralPosition::get(uniquer, attr);
690  }
691 
692  /// Returns the users of the value(s) held by the given result position.
693  UsersPosition *getUsers(Position *p, bool useRepresentative) {
695  ResultGroupPosition>(p)) &&
696  "expected result position");
697  return UsersPosition::get(uniquer, p, useRepresentative);
698  }
699 
700  //===--------------------------------------------------------------------===//
701  // Qualifiers
702  //===--------------------------------------------------------------------===//
703 
704  /// An ordinal predicate consists of a "Question" and a set of acceptable
705  /// "Answers" (later converted to ordinal values). A predicate will query some
706  /// property of a positional value and decide what to do based on the result.
707  using Predicate = std::pair<Qualifier *, Qualifier *>;
708 
709  /// Create a predicate comparing an attribute to a known value.
711  return {AttributeQuestion::get(uniquer),
712  AttributeAnswer::get(uniquer, attr)};
713  }
714 
715  /// Create a predicate checking if two values are equal.
717  return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
718  }
719 
720  /// Create a predicate checking if two values are not equal.
722  return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)};
723  }
724 
725  /// Create a predicate that applies a generic constraint.
727  ArrayRef<Type> resultTypes, bool isNegated) {
728  return {ConstraintQuestion::get(
729  uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
730  TrueAnswer::get(uniquer)};
731  }
732 
733  /// Create a predicate comparing a value with null.
735  return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
736  }
737 
738  /// Create a predicate comparing the number of operands of an operation to a
739  /// known value.
740  Predicate getOperandCount(unsigned count) {
741  return {OperandCountQuestion::get(uniquer),
742  UnsignedAnswer::get(uniquer, count)};
743  }
745  return {OperandCountAtLeastQuestion::get(uniquer),
746  UnsignedAnswer::get(uniquer, count)};
747  }
748 
749  /// Create a predicate comparing the name of an operation to a known value.
750  Predicate getOperationName(StringRef name) {
751  return {OperationNameQuestion::get(uniquer),
752  OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
753  }
754 
755  /// Create a predicate comparing the number of results of an operation to a
756  /// known value.
757  Predicate getResultCount(unsigned count) {
758  return {ResultCountQuestion::get(uniquer),
759  UnsignedAnswer::get(uniquer, count)};
760  }
762  return {ResultCountAtLeastQuestion::get(uniquer),
763  UnsignedAnswer::get(uniquer, count)};
764  }
765 
766  /// Create a predicate comparing the type of an attribute or value to a known
767  /// type. The value is stored as either a TypeAttr, or an ArrayAttr of
768  /// TypeAttr.
770  return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
771  }
772 
773 private:
774  /// The uniquer used when allocating predicate nodes.
775  PredicateUniquer &uniquer;
776 
777  /// The current MLIR context.
778  MLIRContext *ctx;
779 };
780 
781 } // namespace pdl_to_pdl_interp
782 } // namespace mlir
783 
784 #endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
union mlir::linalg::@1193::ArityGroupAndKind::Kind kind
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class acts as the base storage that all storage classes must derive from.
This is a utility allocator used to allocate memory for instances of derived types.
ArrayRef< T > copyInto(ArrayRef< T > elements)
Copy the specified array of elements into memory managed by our bump pointer allocator.
T * allocate()
Allocate an instance of the provided type.
A utility class to get or create instances of "storage classes".
Storage * get(function_ref< void(Storage *)> initFn, TypeID id, Args &&...args)
Gets a uniqued instance of 'Storage'.
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
Position(Predicates::Kind kind)
Definition: Predicate.h:146
unsigned getOperationDepth() const
Returns the depth of the first ancestor operation position.
Definition: Predicate.cpp:21
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition: Predicate.h:153
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
Position * parent
Link to the parent position.
Definition: Predicate.h:160
Base storage for simple predicates that only unique with the kind.
Definition: Predicate.h:119
Base class for all predicates, used to allow efficient pointer comparison.
Definition: Predicate.h:84
bool operator==(const KeyTy &key) const
Utility methods required by the storage allocator.
Definition: Predicate.h:107
static ConcreteT * construct(StorageUniquer::StorageAllocator &alloc, KeyT &&key)
Construct an instance with the given storage allocator.
Definition: Predicate.h:101
static ConcreteT * get(StorageUniquer &uniquer, Args &&...args)
Get an instance of this position.
Definition: Predicate.h:95
static bool classof(const BaseT *pred)
Definition: Predicate.h:108
const KeyTy & getValue() const
Return the key value of this predicate.
Definition: Predicate.h:111
PredicateBase< ConcreteT, BaseT, Key, Kind > Base
Definition: Predicate.h:87
This class provides utilities for constructing predicates.
Definition: Predicate.h:610
ConstraintPosition * getConstraintPosition(ConstraintQuestion *q, unsigned index)
Definition: Predicate.h:636
Position * getTypeLiteral(Attribute attr)
Returns a type position for the given type value.
Definition: Predicate.h:688
Predicate getOperandCount(unsigned count)
Create a predicate comparing the number of operands of an operation to a known value.
Definition: Predicate.h:740
OperationPosition * getPassthroughOp(Position *p)
Returns the operation position equivalent to the given position.
Definition: Predicate.h:630
Predicate getIsNotNull()
Create a predicate comparing a value with null.
Definition: Predicate.h:734
Predicate getOperandCountAtLeast(unsigned count)
Definition: Predicate.h:744
Predicate getResultCountAtLeast(unsigned count)
Definition: Predicate.h:761
Position * getType(Position *p)
Returns a type position for the given entity.
Definition: Predicate.h:684
Position * getAttribute(OperationPosition *p, StringRef name)
Returns an attribute position for an attribute of the given operation.
Definition: Predicate.h:642
Position * getOperandGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of operands of the given operation.
Definition: Predicate.h:661
Position * getForEach(Position *p, unsigned id)
Definition: Predicate.h:651
Position * getOperand(OperationPosition *p, unsigned operand)
Returns an operand position for an operand of the given operation.
Definition: Predicate.h:656
Position * getResult(OperationPosition *p, unsigned result)
Returns a result position for a result of the given operation.
Definition: Predicate.h:670
Position * getRoot()
Returns the root operation position.
Definition: Predicate.h:620
Predicate getAttributeConstraint(Attribute attr)
Create a predicate comparing an attribute to a known value.
Definition: Predicate.h:710
Position * getResultGroup(OperationPosition *p, std::optional< unsigned > group, bool isVariadic)
Returns a position for a group of results of the given operation.
Definition: Predicate.h:675
Position * getAllResults(OperationPosition *p)
Definition: Predicate.h:679
UsersPosition * getUsers(Position *p, bool useRepresentative)
Returns the users of a position using the value at the given operand.
Definition: Predicate.h:693
Predicate getTypeConstraint(Attribute type)
Create a predicate comparing the type of an attribute or value to a known type.
Definition: Predicate.h:769
OperationPosition * getOperandDefiningOp(Position *p)
Returns the parent position defining the value held by the given operand.
Definition: Predicate.h:623
Predicate getResultCount(unsigned count)
Create a predicate comparing the number of results of an operation to a known value.
Definition: Predicate.h:757
std::pair< Qualifier *, Qualifier * > Predicate
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:707
Predicate getEqualTo(Position *pos)
Create a predicate checking if two values are equal.
Definition: Predicate.h:716
Position * getAllOperands(OperationPosition *p)
Definition: Predicate.h:665
PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
Definition: Predicate.h:612
Position * getAttributeLiteral(Attribute attr)
Returns an attribute position for the given attribute.
Definition: Predicate.h:647
Predicate getConstraint(StringRef name, ArrayRef< Position * > args, ArrayRef< Type > resultTypes, bool isNegated)
Create a predicate that applies a generic constraint.
Definition: Predicate.h:726
Predicate getNotEqualTo(Position *pos)
Create a predicate checking if two values are not equal.
Definition: Predicate.h:721
Predicate getOperationName(StringRef name)
Create a predicate comparing the name of an operation to a known value.
Definition: Predicate.h:750
This class provides a storage uniquer that is used to allocate predicate instances.
Definition: Predicate.h:566
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:422
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:427
Qualifier(Predicates::Kind kind)
Definition: Predicate.h:424
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition: Predicate.h:46
inline ::llvm::hash_code hash_value(const PolynomialBase< D, T > &arg)
Definition: Polynomial.h:262
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
An Answer representing an Attribute value.
Definition: Predicate.h:441
A position describing a literal attribute.
Definition: Predicate.h:189
A position describing an attribute of an operation.
Definition: Predicate.h:175
StringAttr getName() const
Returns the attribute name of this position.
Definition: Predicate.h:179
Compare an Attribute to a constant value.
Definition: Predicate.h:485
A position describing the result of a native constraint.
Definition: Predicate.h:301
ConstraintQuestion * getQuestion() const
Returns the ConstraintQuestion to enable keeping track of the native constraint this position stems f...
Definition: Predicate.h:306
Apply a parameterized constraint to multiple position values and possibly produce results.
Definition: Predicate.h:493
StringRef getName() const
Return the name of the constraint.
Definition: Predicate.h:497
ArrayRef< Type > getResultTypes() const
Return the result types of the constraint.
Definition: Predicate.h:503
ArrayRef< Position * > getArgs() const
Return the arguments of the constraint.
Definition: Predicate.h:500
static ConstraintQuestion * construct(StorageUniquer::StorageAllocator &alloc, KeyTy key)
Construct an instance with the given storage allocator.
Definition: Predicate.h:509
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
Definition: Predicate.h:518
bool getIsNegated() const
Return the negation status of the constraint.
Definition: Predicate.h:506
Compare the equality of two values.
Definition: Predicate.h:526
An Answer representing a boolean 'false' value.
Definition: Predicate.h:460
A position describing an iterative choice of an operation.
Definition: Predicate.h:200
unsigned getID() const
Returns the ID, for differentiating various loops.
Definition: Predicate.h:205
Compare a positional value with null, i.e. check if it exists.
Definition: Predicate.h:533
Compare the number of operands of an operation with a known value.
Definition: Predicate.h:538
A position describing an operand group of an operation.
Definition: Predicate.h:232
bool isVariadic() const
Returns if the operand group has unknown size.
Definition: Predicate.h:248
std::optional< unsigned > getOperandGroupNumber() const
Returns the group number of this position.
Definition: Predicate.h:242
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
Definition: Predicate.h:236
A position describing an operand of an operation.
Definition: Predicate.h:216
unsigned getOperandNumber() const
Returns the operand number of this position.
Definition: Predicate.h:220
An Answer representing an OperationName value.
Definition: Predicate.h:448
Compare the name of an operation with a known value.
Definition: Predicate.h:546
An operation position describes an operation node in the IR.
Definition: Predicate.h:259
static OperationPosition * getRoot(StorageUniquer &uniquer)
Gets the root position.
Definition: Predicate.h:270
bool isRoot() const
Returns if this operation position corresponds to the root.
Definition: Predicate.h:283
unsigned getDepth() const
Returns the depth of this position.
Definition: Predicate.h:280
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
Definition: Predicate.h:265
bool isOperandDefiningOp() const
Returns if this operation represents an operand defining op.
Definition: Predicate.cpp:55
static OperationPosition * get(StorageUniquer &uniquer, Position *parent)
Gets an operation position with the given parent.
Definition: Predicate.h:275
Compare the number of results of an operation with a known value.
Definition: Predicate.h:551
A position describing a result group of an operation.
Definition: Predicate.h:336
bool isVariadic() const
Returns if the result group has unknown size.
Definition: Predicate.h:354
std::optional< unsigned > getResultGroupNumber() const
Returns the group number of this position.
Definition: Predicate.h:348
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
Definition: Predicate.h:342
A position describing a result of an operation.
Definition: Predicate.h:320
unsigned getResultNumber() const
Returns the result number of this position.
Definition: Predicate.h:324
An Answer representing a boolean true value.
Definition: Predicate.h:454
An Answer representing a Type value.
Definition: Predicate.h:467
A position describing a literal type or type range.
Definition: Predicate.h:381
A position describing the result type of an entity, i.e.
Definition: Predicate.h:364
Compare the type of an attribute or value with a known type.
Definition: Predicate.h:558
An Answer representing an unsigned value.
Definition: Predicate.h:474
A position describing the users of a value or a range of values.
Definition: Predicate.h:394
bool useRepresentative() const
Indicates whether to compute a range of a representative.
Definition: Predicate.h:403
static llvm::hash_code hashKey(const KeyTy &key)
Returns a hash suitable for the given keytype.
Definition: Predicate.h:398