MLIR 22.0.0git
FoldMemRefsOps.cpp
Go to the documentation of this file.
//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/ADT/TypeSwitch.h"
17
18namespace mlir::amdgpu {
19#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
20#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
21
30
31static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
33 SmallVectorImpl<Value> &resolvedIndices,
34 Value &memrefBase, StringRef role) {
35 Operation *defOp = view.getDefiningOp();
36 if (!defOp) {
37 return failure();
38 }
40 .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
42 rewriter, loc, subviewOp.getMixedOffsets(),
43 subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
44 resolvedIndices);
45 memrefBase = subviewOp.getSource();
46 return success();
47 })
48 .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
50 loc, rewriter, expandShapeOp, indices, resolvedIndices,
51 false))) {
52 return failure();
53 }
54 memrefBase = expandShapeOp.getViewSource();
55 return success();
56 })
57 .Case<memref::CollapseShapeOp>(
58 [&](memref::CollapseShapeOp collapseShapeOp) {
60 loc, rewriter, collapseShapeOp, indices,
61 resolvedIndices))) {
62 return failure();
63 }
64 memrefBase = collapseShapeOp.getViewSource();
65 return success();
66 })
67 .Default([&](Operation *op) {
68 return rewriter.notifyMatchFailure(
69 op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
70 "CollapseShapeOp")
71 .str());
72 });
73}
74
75struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
77 LogicalResult matchAndRewrite(GatherToLDSOp op,
78 PatternRewriter &rewriter) const override {
79 Location loc = op.getLoc();
80
81 SmallVector<Value> sourceIndices, destIndices;
82 Value memrefSource, memrefDest;
83
84 auto foldSrcResult =
85 foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
86 sourceIndices, memrefSource, "source");
87
88 if (failed(foldSrcResult)) {
89 memrefSource = op.getSrc();
90 sourceIndices = op.getSrcIndices();
91 }
92
93 auto foldDstResult =
94 foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
95 destIndices, memrefDest, "destination");
96
97 if (failed(foldDstResult)) {
98 memrefDest = op.getDst();
99 destIndices = op.getDstIndices();
100 }
101
102 rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
103 memrefDest, destIndices,
104 op.getTransferType());
105
106 return success();
107 }
108};
109
114} // namespace mlir::amdgpu
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
OpT getOperation()
Return the current operation being transformed.
Definition Pass.h:378
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MLIRContext & getContext()
Return the MLIR context for the current operation being transformed.
Definition Pass.h:177
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)
LogicalResult resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
LogicalResult resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...