25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
31#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
32#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41struct StridedMetadata {
44 SmallVector<OpFoldResult> sizes;
45 SmallVector<OpFoldResult> strides;
59static FailureOr<StridedMetadata>
61 memref::SubViewOp subview) {
64 Value source = subview.getSource();
65 auto sourceType = cast<MemRefType>(source.
getType());
66 unsigned sourceRank = sourceType.getRank();
68 auto newExtractStridedMetadata =
69 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
71 auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
73 auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
81 auto origStrides = newExtractStridedMetadata.getStrides();
89 values[0] = ShapedType::isDynamic(sourceOffset)
91 : rewriter.getIndexAttr(sourceOffset);
96 for (
unsigned i = 0; i < sourceRank; ++i) {
99 ShapedType::isDynamic(sourceStrides[i])
103 rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
106 unsigned baseIdxForDim = 1 + 2 * i;
107 unsigned subOffsetForDim = baseIdxForDim;
108 unsigned origStrideForDim = baseIdxForDim + 1;
109 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110 values[subOffsetForDim] = subOffsets[i];
111 values[origStrideForDim] = origStride;
121 if (computedOffset && ShapedType::isStatic(resultOffset))
122 assert(*computedOffset == resultOffset &&
123 "mismatch between computed offset and result type offset");
129 auto subType = cast<MemRefType>(subview.getType());
130 unsigned subRank = subType.getRank();
139 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
142 finalSizes.reserve(subRank);
145 finalStrides.reserve(subRank);
151 for (
unsigned i = 0; i < sourceRank; ++i) {
152 if (droppedDims.test(i))
155 finalSizes.push_back(subSizes[i]);
156 finalStrides.push_back(strides[i]);
161 if (computedStride && ShapedType::isStatic(resultStrides[
j]))
162 assert(*computedStride == resultStrides[
j] &&
163 "mismatch between computed stride and result type stride");
167 assert(finalSizes.size() == subRank &&
168 "Should have populated all the values at this point");
169 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170 finalSizes, finalStrides};
189 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
191 LogicalResult matchAndRewrite(memref::SubViewOp subview,
192 PatternRewriter &rewriter)
const override {
193 FailureOr<StridedMetadata> stridedMetadata =
194 resolveSubviewStridedMetadata(rewriter, subview);
195 if (
failed(stridedMetadata)) {
197 "failed to resolve subview metadata");
201 subview, subview.getType(), stridedMetadata->basePtr,
202 stridedMetadata->offset, stridedMetadata->sizes,
203 stridedMetadata->strides);
221struct ExtractStridedMetadataOpSubviewFolder
226 PatternRewriter &rewriter)
const override {
227 auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
231 FailureOr<StridedMetadata> stridedMetadata =
232 resolveSubviewStridedMetadata(rewriter, subviewOp);
233 if (
failed(stridedMetadata)) {
235 op,
"failed to resolve metadata in terms of source subview op");
237 Location loc = subviewOp.getLoc();
238 SmallVector<Value> results;
239 results.reserve(subviewOp.getType().getRank() * 2 + 2);
240 results.push_back(stridedMetadata->basePtr);
242 stridedMetadata->offset));
246 stridedMetadata->strides));
266getExpandedSizes(memref::ExpandShapeOp expandShape,
OpBuilder &builder,
269 expandShape.getReassociationIndices()[groupId];
270 assert(!reassocGroup.empty() &&
271 "Reassociation group should have at least one dimension");
275 for (
auto index : reassocGroup)
276 expandedSizes.push_back(outputShape[
index]);
278 return expandedSizes;
307 expandShape.getReassociationIndices()[groupId];
308 assert(!reassocGroup.empty() &&
309 "Reassociation group should have at least one dimension");
311 unsigned groupSize = reassocGroup.size();
312 Location loc = expandShape.getLoc();
321 Value source = expandShape.getSrc();
322 auto sourceType = cast<MemRefType>(source.
getType());
323 auto [strides, offset] = sourceType.getStridesAndOffset();
325 OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
326 ? origStrides[groupId]
333 for (
int i = groupSize - 1; i >= 0; --i) {
334 expandedStrides[i] = currentStride;
335 currentStride =
mul(currentStride, outputShape[reassocGroup[i]]);
338 return expandedStrides;
355 unsigned numberOfSymbols = 0;
356 unsigned groupSize =
indices.size();
357 for (
unsigned i = 0; i < groupSize; ++i) {
361 int64_t maybeConstant = maybeConstants[srcIdx];
363 inputValues.push_back(isDynamic(maybeConstant)
387getCollapsedSize(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
391 MemRefType collapseShapeType = collapseShape.getResultType();
393 uint64_t size = collapseShapeType.getDimSize(groupId);
394 if (ShapedType::isStatic(size)) {
396 return collapsedSize;
402 Value source = collapseShape.getSrc();
403 auto sourceType = cast<MemRefType>(source.
getType());
406 collapseShape.getReassociationIndices()[groupId];
408 collapsedSize.push_back(getProductOfValues(
409 reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
410 origSizes, ShapedType::isDynamic));
412 return collapsedSize;
428getCollapsedStride(memref::CollapseShapeOp collapseShape,
OpBuilder &builder,
432 collapseShape.getReassociationIndices()[groupId];
433 assert(!reassocGroup.empty() &&
434 "Reassociation group should have at least one dimension");
436 Value source = collapseShape.getSrc();
437 auto sourceType = cast<MemRefType>(source.
getType());
439 auto [strides, offset] = sourceType.getStridesAndOffset();
444 for (
int64_t currentDim : reassocGroup) {
450 if (srcShape[currentDim] == 1)
453 int64_t currentStride = strides[currentDim];
454 lastValidStride = ShapedType::isDynamic(currentStride)
455 ? origStrides[currentDim]
458 if (!lastValidStride) {
461 MemRefType collapsedType = collapseShape.getResultType();
462 auto [collapsedStrides, collapsedOffset] =
463 collapsedType.getStridesAndOffset();
464 int64_t finalStride = collapsedStrides[groupId];
465 if (ShapedType::isDynamic(finalStride)) {
468 for (
int64_t currentDim : reassocGroup) {
469 assert(srcShape[currentDim] == 1 &&
470 "We should be dealing with 1x1x...x1");
472 if (ShapedType::isDynamic(strides[currentDim]))
473 return {origStrides[currentDim]};
475 llvm_unreachable(
"We should have found a dynamic stride");
480 return {lastValidStride};
493template <
typename ReassociativeReshapeLikeOp>
494static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
495 RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
504 getReshapedStrides) {
507 Location origLoc = reshape.getLoc();
508 Value source = reshape.getSrc();
509 auto sourceType = cast<MemRefType>(source.
getType());
510 unsigned sourceRank = sourceType.getRank();
512 auto newExtractStridedMetadata =
513 memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);
516 auto [strides, offset] = sourceType.getStridesAndOffset();
517 MemRefType reshapeType = reshape.getResultType();
518 unsigned reshapeRank = reshapeType.getRank();
521 ShapedType::isDynamic(offset)
523 : rewriter.getIndexAttr(offset);
526 if (sourceRank == 0) {
528 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
533 finalSizes.reserve(reshapeRank);
535 finalStrides.reserve(reshapeRank);
542 unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
543 for (; idx != endIdx; ++idx) {
545 getReshapedSizes(reshape, rewriter, origSizes, idx);
547 reshape, rewriter, origSizes, origStrides, idx);
549 unsigned groupSize = reshapedSizes.size();
550 for (
unsigned i = 0; i < groupSize; ++i) {
551 finalSizes.push_back(reshapedSizes[i]);
552 finalStrides.push_back(reshapedStrides[i]);
555 assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
556 (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
557 "We should have visited all the input dimensions");
558 assert(finalSizes.size() == reshapeRank &&
559 "We should have populated all the values");
561 return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
562 finalSizes, finalStrides};
581template <
typename ReassociativeReshapeLikeOp,
593 LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
595 FailureOr<StridedMetadata> stridedMetadata =
596 resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
597 rewriter, reshape, getReshapedSizes, getReshapedStrides);
598 if (
failed(stridedMetadata)) {
600 "failed to resolve reshape metadata");
604 reshape, reshape.getType(), stridedMetadata->basePtr,
605 stridedMetadata->offset, stridedMetadata->sizes,
606 stridedMetadata->strides);
624struct ExtractStridedMetadataOpCollapseShapeFolder
628 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
630 auto collapseShapeOp =
631 op.getSource().getDefiningOp<memref::CollapseShapeOp>();
632 if (!collapseShapeOp)
635 FailureOr<StridedMetadata> stridedMetadata =
636 resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
637 rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
638 if (
failed(stridedMetadata)) {
641 "failed to resolve metadata in terms of source collapse_shape op");
644 Location loc = collapseShapeOp.getLoc();
646 results.push_back(stridedMetadata->basePtr);
648 stridedMetadata->offset));
652 stridedMetadata->strides));
661struct ExtractStridedMetadataOpExpandShapeFolder
665 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
667 auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
671 FailureOr<StridedMetadata> stridedMetadata =
672 resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
673 rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
674 if (
failed(stridedMetadata)) {
676 op,
"failed to resolve metadata in terms of source expand_shape op");
679 Location loc = expandShapeOp.getLoc();
681 results.push_back(stridedMetadata->basePtr);
683 stridedMetadata->offset));
687 stridedMetadata->strides));
707template <
typename AllocLikeOp>
708struct ExtractStridedMetadataOpAllocFolder
713 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
715 auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
719 auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
720 if (!memRefType.getLayout().isIdentity())
722 allocLikeOp,
"alloc-like operations should have been normalized");
725 int rank = memRefType.getRank();
728 ValueRange dynamic = allocLikeOp.getDynamicSizes();
731 unsigned dynamicPos = 0;
732 for (
int64_t size : memRefType.getShape()) {
733 if (ShapedType::isDynamic(size))
734 sizes.push_back(dynamic[dynamicPos++]);
742 unsigned symbolNumber = 0;
743 for (
int i = rank - 2; i >= 0; --i) {
745 assert(i + 1 + symbolNumber == sizes.size() &&
746 "The ArrayRef should encompass the last #symbolNumber sizes");
749 sizesInvolvedInStride);
754 results.reserve(rank * 2 + 2);
756 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
758 if (op.getBaseBuffer().use_empty()) {
759 results.push_back(
nullptr);
761 if (allocLikeOp.getType() == baseBufferType)
762 results.push_back(allocLikeOp);
764 results.push_back(memref::ReinterpretCastOp::create(
765 rewriter, loc, baseBufferType, allocLikeOp, offset,
798struct ExtractStridedMetadataOpGetGlobalFolder
801 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
804 PatternRewriter &rewriter)
const override {
805 auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
809 auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
810 if (!memRefType.getLayout().isIdentity()) {
813 "get-global operation result should have been normalized");
816 Location loc = op.getLoc();
817 int rank = memRefType.getRank();
820 ArrayRef<int64_t> sizes = memRefType.getShape();
821 assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
822 "unexpected dynamic shape for result of `memref.get_global` op");
828 SmallVector<Value> results;
829 results.reserve(rank * 2 + 2);
831 auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
833 if (getGlobalOp.getType() == baseBufferType)
834 results.push_back(getGlobalOp);
836 results.push_back(memref::ReinterpretCastOp::create(
837 rewriter, loc, baseBufferType, getGlobalOp, offset,
839 ArrayRef<int64_t>()));
844 for (
auto size : sizes)
847 for (
auto stride : strides)
866struct ExtractStridedMetadataOpAssumeAlignmentFolder
869 using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
871 LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
872 PatternRewriter &rewriter)
const override {
873 auto assumeAlignmentOp =
874 op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
875 if (!assumeAlignmentOp)
879 op, assumeAlignmentOp.getViewSource());
886class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
891 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
892 PatternRewriter &rewriter)
const override {
894 extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
898 if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() ||
899 !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp))
902 extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
921class ExtractStridedMetadataOpReinterpretCastFolder
926 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
927 PatternRewriter &rewriter)
const override {
928 auto reinterpretCastOp = extractStridedMetadataOp.getSource()
929 .getDefiningOp<memref::ReinterpretCastOp>();
930 if (!reinterpretCastOp)
933 Location loc = extractStridedMetadataOp.getLoc();
935 SmallVector<Type> inferredReturnTypes;
936 if (
failed(extractStridedMetadataOp.inferReturnTypes(
937 rewriter.
getContext(), loc, {reinterpretCastOp.getSource()},
939 inferredReturnTypes)))
941 reinterpretCastOp,
"reinterpret_cast source's type is incompatible");
943 auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
944 unsigned rank = memrefType.getRank();
945 SmallVector<OpFoldResult> results;
946 results.resize_for_overwrite(rank * 2 + 2);
948 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
949 rewriter, loc, reinterpretCastOp.getSource());
952 results[0] = newExtractStridedMetadata.getBaseBuffer();
956 rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
958 const unsigned sizeStartIdx = 2;
959 const unsigned strideStartIdx = sizeStartIdx + rank;
961 SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
962 SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
963 for (
unsigned i = 0; i < rank; ++i) {
964 results[sizeStartIdx + i] = sizes[i];
965 results[strideStartIdx + i] = strides[i];
967 rewriter.
replaceOp(extractStridedMetadataOp,
983class ExtractStridedMetadataOpMemorySpaceCastFolder
988 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
989 PatternRewriter &rewriter)
const override {
990 Location loc = extractStridedMetadataOp.getLoc();
991 Value source = extractStridedMetadataOp.getSource();
992 auto memSpaceCastOp = source.
getDefiningOp<memref::MemorySpaceCastOp>();
995 auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
996 rewriter, loc, memSpaceCastOp.getSource());
997 SmallVector<Value> results(newExtractStridedMetadata.getResults());
1004 if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1005 auto baseBuffer = results[0];
1006 auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1007 MemRefType::Builder newTypeBuilder(baseBufferType);
1008 newTypeBuilder.setMemorySpace(
1009 memSpaceCastOp.getResult().getType().getMemorySpace());
1010 results[0] = memref::MemorySpaceCastOp::create(
1011 rewriter, loc, Type{newTypeBuilder}, baseBuffer);
1013 results[0] =
nullptr;
1015 rewriter.
replaceOp(extractStridedMetadataOp, results);
1027class ExtractStridedMetadataOpExtractStridedMetadataFolder
1032 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1033 PatternRewriter &rewriter)
const override {
1034 auto sourceExtractStridedMetadataOp =
1035 extractStridedMetadataOp.getSource()
1036 .getDefiningOp<memref::ExtractStridedMetadataOp>();
1037 if (!sourceExtractStridedMetadataOp)
1039 Location loc = extractStridedMetadataOp.getLoc();
1040 rewriter.
replaceOp(extractStridedMetadataOp,
1041 {sourceExtractStridedMetadataOp.getBaseBuffer(),
1051 patterns.
add<SubviewFolder,
1052 ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1053 getExpandedStrides>,
1054 ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1055 getCollapsedStride>,
1056 ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1057 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1058 ExtractStridedMetadataOpCollapseShapeFolder,
1059 ExtractStridedMetadataOpExpandShapeFolder,
1060 ExtractStridedMetadataOpGetGlobalFolder,
1061 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1062 ExtractStridedMetadataOpReinterpretCastFolder,
1063 ExtractStridedMetadataOpSubviewFolder,
1064 ExtractStridedMetadataOpMemorySpaceCastFolder,
1065 ExtractStridedMetadataOpAssumeAlignmentFolder,
1066 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1072 patterns.
add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1073 ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1074 ExtractStridedMetadataOpCollapseShapeFolder,
1075 ExtractStridedMetadataOpExpandShapeFolder,
1076 ExtractStridedMetadataOpGetGlobalFolder,
1077 ExtractStridedMetadataOpSubviewFolder,
1078 RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1079 ExtractStridedMetadataOpReinterpretCastFolder,
1080 ExtractStridedMetadataOpMemorySpaceCastFolder,
1081 ExtractStridedMetadataOpAssumeAlignmentFolder,
1082 ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1092struct ExpandStridedMetadataPass final
1093 :
public memref::impl::ExpandStridedMetadataPassBase<
1094 ExpandStridedMetadataPass> {
1095 void runOnOperation()
override;
1100void ExpandStridedMetadataPass::runOnOperation() {
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for resolving memref.extract_strided_metadata into memref.extract_strided_metadata o...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::function_ref< Fn > function_ref
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.