9 #ifndef MLIR_IR_PATTERNMATCH_H 10 #define MLIR_IR_PATTERNMATCH_H 14 #include "llvm/ADT/FunctionExtras.h" 15 #include "llvm/Support/TypeName.h" 19 class PatternRewriter;
33 enum { ImpossibleToMatchSentinel = 65535 };
49 return representation == rhs.representation;
53 return representation < rhs.representation;
60 unsigned short representation{ImpossibleToMatchSentinel};
93 if (rootKind == RootKind::OperationName)
102 if (rootKind == RootKind::InterfaceID)
111 if (rootKind == RootKind::TraitID)
128 return contextAndHasBoundedRecursion.getInt();
133 return contextAndHasBoundedRecursion.getPointer();
149 debugLabels.append(labels.begin(), labels.end());
201 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
205 Pattern(
const void *rootValue, RootKind rootKind,
210 const void *rootValue;
218 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
273 template <
typename T,
typename... Args>
274 static std::unique_ptr<T>
create(Args &&... args) {
275 std::unique_ptr<T> pattern =
276 std::make_unique<T>(std::forward<Args>(args)...);
277 initializePattern<T>(*pattern);
280 if (pattern->getDebugName().empty())
281 pattern->setDebugName(llvm::getTypeName<T>());
291 template <
typename T,
typename... Args>
292 using has_initialize = decltype(std::declval<T>().initialize());
293 template <
typename T>
294 using detect_has_initialize = llvm::is_detected<has_initialize, T>;
297 template <
typename T>
299 initializePattern(T &pattern) {
300 pattern.initialize();
304 template <
typename T>
306 initializePattern(T &) {}
309 virtual void anchor();
316 template <
typename SourceOp>
318 using RewritePattern::RewritePattern;
322 rewrite(cast<SourceOp>(op), rewriter);
325 return match(cast<SourceOp>(op));
329 return matchAndRewrite(cast<SourceOp>(op), rewriter);
335 llvm_unreachable(
"must override rewrite or matchAndRewrite");
338 llvm_unreachable(
"must override match or matchAndRewrite");
354 template <
typename SourceOp>
363 SourceOp::getOperationName(), benefit, context, generatedNames) {}
369 template <
typename SourceOp>
373 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
381 template <
template <
typename>
class TraitType>
404 virtual void inlineRegionBefore(
Region ®ion,
Region &parent,
406 void inlineRegionBefore(
Region ®ion,
Block *before);
412 virtual void cloneRegionBefore(
Region ®ion,
Region &parent,
417 void cloneRegionBefore(
Region ®ion,
Block *before);
429 llvm::unique_function<
bool(
OpOperand &)
const> functor);
431 llvm::unique_function<
bool(
OpOperand &)
const> functor) {
432 replaceOpWithIf(op, newValues,
nullptr,
441 bool *allUsesReplaced =
nullptr);
450 template <
typename OpTy,
typename... Args>
452 auto newOp = create<OpTy>(op->
getLoc(), std::forward<Args>(args)...);
453 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
461 virtual void eraseBlock(
Block *block);
467 virtual void mergeBlocks(
Block *source,
Block *dest,
498 template <
typename CallableT>
500 startRootUpdate(root);
502 finalizeRootUpdate(root);
510 template <
typename CallbackT>
514 return notifyMatchFailure(loc,
520 template <
typename CallbackT>
523 return notifyMatchFailure(op->
getLoc(),
526 template <
typename ArgT>
528 return notifyMatchFailure(std::forward<ArgT>(arg),
531 template <
typename ArgT>
533 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
623 : value(value.getAsOpaquePointer()), kind(
Kind::
Attribute) {}
628 : value(value.getAsOpaquePointer()), kind(
Kind::
Value) {}
632 template <
typename T>
634 assert(
value &&
"isa<> used on a null value");
635 return kind == getKindOf<T>();
640 template <
typename T,
641 typename ResultT = std::conditional_t<
644 return isa<T>() ? castImpl<T>() : ResultT();
649 template <
typename T>
651 assert(isa<T>() &&
"expected value to be of type `T`");
652 return castImpl<T>();
659 explicit operator bool()
const {
return value; }
665 void print(raw_ostream &os)
const;
668 static void print(raw_ostream &os,
Kind kind);
672 template <
typename...>
674 template <
typename T,
typename... R>
675 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
676 template <
typename T,
typename F,
typename... R>
677 struct index_of_t<T, F, R...>
678 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
681 template <
typename T>
682 static Kind getKindOf() {
689 template <
typename T>
692 return T::getFromOpaquePointer(
value);
694 template <
typename T>
697 return *
reinterpret_cast<T *
>(
const_cast<void *
>(
value));
699 template <
typename T>
701 return reinterpret_cast<T
>(
const_cast<void *
>(
value));
705 const void *
value{
nullptr};
707 Kind kind{Kind::Attribute};
741 llvm::OwningArrayRef<Type> storage(value.size());
743 allocatedTypeRanges.emplace_back(std::move(storage));
744 typeRanges.push_back(allocatedTypeRanges.back());
745 results.push_back(&typeRanges.back());
748 typeRanges.push_back(value);
749 results.push_back(&typeRanges.back());
752 typeRanges.push_back(value);
753 results.push_back(&typeRanges.back());
763 llvm::OwningArrayRef<Value> storage(value.size());
765 allocatedValueRanges.emplace_back(std::move(storage));
766 valueRanges.push_back(allocatedValueRanges.back());
767 results.push_back(&valueRanges.back());
770 valueRanges.push_back(value);
771 results.push_back(&valueRanges.back());
774 valueRanges.push_back(value);
775 results.push_back(&valueRanges.back());
784 typeRanges.reserve(maxNumResults);
785 valueRanges.reserve(maxNumResults);
806 std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
812 std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
815 namespace pdl_function_builder {
826 template <
class... T>
862 template <
typename T,
typename Enable =
void>
882 template <
typename T,
typename BaseT>
906 static T processAsArg(
BaseT baseValue);
912 template <
typename T>
919 return errorFn(
"expected a non-null value for argument " + Twine(argIdx) +
920 " of type: " + llvm::getTypeName<T>());
934 template <
typename T,
typename BaseT>
938 BaseT baseValue,
size_t argIdx) {
940 .Case([&](T) {
return success(); })
941 .Default([&](
BaseT) {
942 return errorFn(
"expected argument " + Twine(argIdx) +
943 " to be of type: " + llvm::getTypeName<T>());
949 return baseValue.template cast<T>();
964 template <
typename T>
966 std::enable_if_t<std::is_base_of<Attribute, T>::value>>
984 template <
typename T>
986 static_assert(always_false<T>,
987 "`std::string` arguments require a string copy, use " 988 "`StringRef` for string-like arguments instead");
1003 template <
typename T>
1014 template <
typename T>
1074 template <
typename PDLFnT, std::size_t... I>
1076 std::index_sequence<I...>) {
1077 using FnTraitsT = llvm::function_traits<PDLFnT>;
1079 auto errorFn = [&](
const Twine &msg) {
1083 (
void)std::initializer_list<int>{
1087 verifyAsArg(errorFn, values[I], I)
1096 template <
typename PDLFnT, std::size_t... I>
1098 std::index_sequence<I...>) {
1100 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 1101 using FnTraitsT = llvm::function_traits<PDLFnT>;
1103 llvm::report_fatal_error(msg);
1105 (
void)std::initializer_list<int>{
1107 I + 1>>::verifyAsArg(errorFn, values[I], I))),
1117 template <
typename T>
1121 std::forward<T>(
value));
1125 template <
typename T1,
typename T2>
1127 std::pair<T1, T2> &&pair) {
1133 template <
typename... Ts>
1135 std::tuple<Ts...> &&tuple) {
1136 auto applyFn = [&](
auto &&...args) {
1139 (
void)std::initializer_list<int>{
1142 llvm::apply_tuple(applyFn, std::move(tuple));
1150 template <
typename PDLFnT, std::size_t... I,
1151 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1152 typename FnTraitsT::result_t
1155 std::index_sequence<I...>) {
1158 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1168 template <
typename Constra
intFnT>
1173 return std::forward<ConstraintFnT>(constraintFn);
1177 template <
typename Constra
intFnT>
1179 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1182 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1185 auto argIndices = std::make_index_sequence<
1186 llvm::function_traits<ConstraintFnT>::num_args - 1>();
1187 if (
failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1200 template <
typename PDLFnT, std::size_t... I,
1201 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1205 std::index_sequence<I...>) {
1207 (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1212 template <
typename PDLFnT, std::size_t... I,
1213 typename FnTraitsT = llvm::function_traits<PDLFnT>>
1217 std::index_sequence<I...>) {
1220 fn(rewriter, (
ProcessPDLValue<
typename FnTraitsT::template arg_t<I + 1>>::
1221 processAsArg(values[I]))...));
1230 template <
typename RewriteFnT>
1234 return std::forward<RewriteFnT>(rewriteFn);
1238 template <
typename RewriteFnT>
1242 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1246 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1248 assertArgs<RewriteFnT>(rewriter, values, argIndices);
1267 : pdlModule(std::move(pdlModule)) {}
1297 void registerConstraintFunction(StringRef name,
1299 template <
typename Constra
intFnT>
1301 ConstraintFnT &&constraintFn) {
1302 registerConstraintFunction(name,
1304 std::forward<ConstraintFnT>(constraintFn)));
1332 template <
typename RewriteFnT>
1335 std::forward<RewriteFnT>(rewriteFn)));
1340 return constraintFunctions;
1343 return constraintFunctions;
1347 return rewriteFunctions;
1350 return rewriteFunctions;
1355 pdlModule =
nullptr;
1356 constraintFunctions.clear();
1357 rewriteFunctions.clear();
1365 llvm::StringMap<PDLConstraintFunction> constraintFunctions;
1366 llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1374 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1381 std::unique_ptr<RewritePattern> pattern)
1382 : context(context) {
1383 nativePatterns.emplace_back(std::move(pattern));
1386 : context(pattern.getModule()->getContext()),
1387 pdlPatterns(std::move(pattern)) {}
1399 nativePatterns.clear();
1400 pdlPatterns.clear();
1410 template <
typename... Ts,
typename ConstructorArg,
1411 typename... ConstructorArgs,
1412 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1418 (
void)std::initializer_list<int>{
1419 0, (addImpl<Ts>(llvm::None, arg, args...), 0)...};
1426 template <
typename... Ts,
typename ConstructorArg,
1427 typename... ConstructorArgs,
1428 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1430 ConstructorArg &&arg,
1431 ConstructorArgs &&... args) {
1436 (
void)std::initializer_list<int>{
1437 0, (addImpl<Ts>(debugLabels, arg, args...), 0)...};
1443 template <
typename... Ts>
1445 (
void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
1452 nativePatterns.emplace_back(std::move(pattern));
1459 pdlPatterns.mergeIn(std::move(pattern));
1464 template <
typename OpType>
1474 return implFn(op, rewriter);
1480 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1493 template <
typename... Ts,
typename ConstructorArg,
1494 typename... ConstructorArgs,
1495 typename = std::enable_if_t<
sizeof...(Ts) != 0>>
1501 (
void)std::initializer_list<int>{
1502 0, (addImpl<Ts>(llvm::None, arg, args...), 0)...};
1508 template <
typename... Ts>
1510 (
void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
1517 nativePatterns.emplace_back(std::move(pattern));
1524 pdlPatterns.mergeIn(std::move(pattern));
1529 template <
typename OpType>
1536 this->setDebugName(llvm::getTypeName<FnPattern>());
1541 return implFn(op, rewriter);
1547 add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1554 template <
typename T,
typename... Args>
1557 std::unique_ptr<T> pattern =
1558 RewritePattern::create<T>(std::forward<Args>(args)...);
1559 pattern->addDebugLabels(debugLabels);
1560 nativePatterns.emplace_back(std::move(pattern));
1562 template <
typename T,
typename... Args>
1567 pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1571 NativePatternListT nativePatterns;
1577 #endif // MLIR_IR_PATTERNMATCH_H T cast() const
Cast this value to type T, asserts if this value is not an instance of T.
bool operator<(const PatternBenefit &rhs) const
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static T processAsArg(PDLValue pdlValue)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
static void processAsResult(PatternRewriter &, PDLResultList &results, OperandRange values)
This class contains a list of basic blocks and a link to the parent operation it is attached to...
PDLValue(TypeRange *value)
static std::string diag(llvm::Value &v)
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
void push_back(ResultRange value)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Optional< TypeID > getRootInterfaceID() const
Return the interface ID used to match the root operation of this pattern.
This class acts as a special tag that makes the desire to match any operation that implements a given...
Operation is a basic unit of execution within MLIR.
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Rewrite and Match methods that operate on the SourceOp type.
bool isa() const
Returns true if the type of the held value is T.
ResultT dyn_cast() const
Attempt to dynamically cast this value to type T, returns null if this value is not an instance of T...
RewritePatternSet & insert(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
void push_back(OperandRange value)
const llvm::StringMap< PDLRewriteFunction > & getRewriteFunctions() const
Return the set of the registered rewrite functions.
SmallVector< PDLValue > results
The PDL results held by this list.
Block represents an ordered list of Operations.
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT value, size_t argIdx)
Explicitly add the expected parent API to ensure the parent class implements the necessary API (and d...
PDLResultList(unsigned maxNumResults)
Create a new result list with the expected number of results.
IRRewriter(const OpBuilder &builder)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This struct provides a simplified model for processing types that are based on another type...
RewritePatternSet & add(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
RewritePatternSet & insert(PDLPatternModule &&pattern)
Add the given PDL pattern to the pattern list.
BlockListType::iterator iterator
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
This class implements the result iterators for the Operation class.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
void push_back(Operation *value)
Push a new Operation onto the result list.
LogicalResult match(Operation *op) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value)
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Validate the given PDLValues match the constraints defined by the argument types of the given functio...
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
PatternBenefit & operator=(const PatternBenefit &)=default
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg)
const void * getAsOpaquePointer() const
Get an opaque pointer to the value.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
static StringRef processAsArg(StringAttr value)
This class represents a listener that may be used to hook into various actions within an OpBuilder...
RewritePatternSet(MLIRContext *context, std::unique_ptr< RewritePattern > pattern)
Construct a RewritePatternSet populated with the given pattern.
PDLValue(Attribute value)
FnTraitsT::result_t processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Process the arguments of a native constraint and invoke it.
static constexpr const bool value
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn)
This class provides an efficient unique identifier for a specific C++ type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void addDebugLabels(StringRef label)
std::function< void(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
static OperationName getFromOpaquePointer(const void *pointer)
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
NativePatternListT & getNativePatterns()
Return the native patterns held in this list.
static T processAsArg(Operation *value)
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
RewritePattern is the common base class for all DAG to DAG replacements.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg=true)
Set the flag detailing if this pattern has bounded rewrite recursion or not.
void push_back(TypeRange value)
Push a new TypeRange onto the result list.
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
ModuleOp getModule()
Return the internal PDL module of this pattern.
static std::unique_ptr< T > create(Args &&... args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
PDLValue(ValueRange *value)
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static PatternBenefit impossibleToMatch()
bool operator<=(const PatternBenefit &rhs) const
OpListType::iterator iterator
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
virtual LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback)
Notify the rewriter that the pattern failed to match the given operation, and provide a callback to p...
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< ResultRange > types)
static T processAsArg(PDLValue pdlValue)
Storage type of byte-code interpreter values.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Attributes are known-constant values of operations.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
void push_back(Attribute value)
Push a new Attribute value onto the result list.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef< StringRef > generatedNames={})
Construct a pattern with a certain benefit that matches the operation with the given root name...
void assertArgs(PatternRewriter &rewriter, ArrayRef< PDLValue > values, std::index_sequence< I... >)
Assert that the given PDLValues match the constraints defined by the arguments of the given function...
void clear()
Clear out the patterns and functions within this module.
Optional< TypeID > getRootTraitID() const
Return the trait ID used to match the root operation of this pattern.
ArrayRef< StringRef > getDebugLabels() const
Return the set of debug labels attached to this pattern.
void push_back(ValueRange value)
Push a new ValueRange onto the result list.
bool operator>=(const PatternBenefit &rhs) const
ArrayRef< OperationName > getGeneratedOps() const
Return a list of operations that may be generated when rewriting an operation instance with this patt...
void setDebugName(StringRef name)
Set the human readable debug name used for this pattern.
This class provides an abstraction over the various different ranges of value types.
void push_back(Type value)
Push a new Type onto the result list.
virtual void notifyRootReplaced(Operation *op)
These are the callback methods that subclasses can choose to implement if they would like to be notif...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
std::enable_if_t<!std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Otherwise, we generate a wrapper that will unpack the PDLValues in the form we desire.
Location getLoc()
The source location the operation was defined or derived from.
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class acts as a special tag that makes the desire to match "any" operation type explicit...
RewriterBase(MLIRContext *ctx)
Initialize the builder with this rewriter as the listener.
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, BaseT baseValue, size_t argIdx)
virtual LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
bool hasBoundedRewriteRecursion() const
Returns true if this pattern is known to result in recursive application, i.e.
RewritePatternSet & add()
Add an instance of each of the pattern types 'Ts'.
const llvm::StringMap< PDLConstraintFunction > & getConstraintFunctions() const
Return the set of the registered constraint functions.
std::enable_if_t<!std::is_same< typename FnTraitsT::result_t, void >::value > processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &results, ArrayRef< PDLValue > values, std::index_sequence< I... >)
This overload handles the case of return values, which need to be packaged into the result list...
std::enable_if_t< std::is_convertible< RewriteFnT, PDLRewriteFunction >::value, PDLRewriteFunction > buildRewriteFn(RewriteFnT &&rewriteFn)
Build a rewrite function from the given function RewriteFnT.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Kind
The underlying kind of a PDL value.
RewriterBase(const OpBuilder &otherBuilder)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg)
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
bool operator>(const PatternBenefit &rhs) const
This struct provides a simplified model for processing types that have "builtin" PDLValue support: ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
RewritePatternSet & add(std::unique_ptr< RewritePattern > pattern)
Add the given native pattern to the pattern list.
SmallVector< TypeRange > typeRanges
Memory used to store ranges held by the list.
SmallVector< ValueRange > valueRanges
PDLPatternModule(OwningOpRef< ModuleOp > pdlModule)
Construct a PDL pattern with the given module.
RewritePatternSet(MLIRContext *context)
This class implements iteration on the types of a given range of values.
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
static TypeID getFromOpaquePointer(const void *pointer)
bool isImpossibleToMatch() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static void processResults(PatternRewriter &rewriter, PDLResultList &results, std::tuple< Ts... > &&tuple)
Store a std::tuple<> as individual results within the result list.
void addDebugLabels(ArrayRef< StringRef > labels)
Add the provided debug labels to this pattern.
Kind getKind() const
Return the kind of this value.
virtual LogicalResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
virtual LogicalResult match(SourceOp op) const
RewritePatternSet & add(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
IRRewriter(MLIRContext *ctx)
MLIRContext is the top-level object for a collection of MLIR operations.
bool operator==(const PatternBenefit &rhs) const
This class represents an operand of an operation.
void clear()
Clear out all of the held patterns in this list.
OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that allows for matching and rewri...
The class represents a list of PDL results, returned by a native rewrite method.
This class implements the operand iterators for the Operation class.
void replaceOpWithIf(Operation *op, ValueRange newValues, llvm::unique_function< bool(OpOperand &) const > functor)
PDLValue(Operation *value)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
This class acts as a special tag that makes the desire to match any operation that implements a given...
This struct provides a convenient way to determine how to process a given type as either a PDL parame...
void push_back(ValueTypeRange< OperandRange > value)
void push_back(ValueTypeRange< ResultRange > value)
PDLPatternModule & getPDLPatterns()
Return the PDL patterns held in this list.
static void processAsResult(PatternRewriter &, PDLResultList &results, ResultRange values)
virtual void notifyOperationRemoved(Operation *op)
This is called on an operation that a rewrite is removing, right before the operation is deleted...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
RewritePatternSet(PDLPatternModule &&pattern)
This struct provides a simplified model for processing types that inherit from builtin PDLValue types...
void push_back(Value value)
Push a new Value onto the result list.
void registerConstraintFunction(StringRef name, ConstraintFnT &&constraintFn)
static void processAsResult(PatternRewriter &, PDLResultList &results, T value)
static LogicalResult verifyAsArg(function_ref< LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, size_t argIdx)
StringRef getDebugName() const
Return a readable name for this pattern.
RewritePatternSet & addWithLabel(ArrayRef< StringRef > debugLabels, ConstructorArg &&arg, ConstructorArgs &&... args)
An overload of the above add method that allows for attaching a set of debug labels to the attached p...
static T processAsArg(BaseT baseValue)
PDLValue(std::nullptr_t=nullptr)
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
SmallVector< llvm::OwningArrayRef< Type > > allocatedTypeRanges
Memory allocated to store ranges in the result list whose lifetime was generated in the native functi...
bool operator!=(const PatternBenefit &rhs) const
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
void print(raw_ostream &os) const
Print this value to the provided output stream.
RewritePatternSet & insert(LogicalResult(*implFn)(OpType, PatternRewriter &rewriter))
constexpr bool always_false
A utility variable that always resolves to false.
StringAttr getStringAttr(const Twine &bytes)
SmallVector< llvm::OwningArrayRef< Value > > allocatedValueRanges
RewritePatternSet & insert()
Add an instance of each of the pattern types 'Ts'.
static void processAsResult(PatternRewriter &, PDLResultList &results, ValueTypeRange< OperandRange > types)
PatternBenefit getBenefit() const
Return the benefit (the inverse of "cost") of matching this pattern.
MLIRContext * getContext() const
std::enable_if_t< !std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Otherwise, we generate a wrapper that will unpack the PDLValues in the form we desire.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.
static std::string processAsArg(T value)
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Optional< OperationName > getRootKind() const
Return the root node that this pattern matches.