31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
38 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
47 if (
Attribute offsetAttr = dyn_cast<Attribute>(in)) {
48 return rewriter.
create<arith::ConstantIndexOp>(
49 loc, cast<IntegerAttr>(offsetAttr).getInt());
51 return cast<Value>(in);
62 auto sourceType = cast<MemRefType>(source.
getType());
63 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
67 memref::ExtractStridedMetadataOp stridedMetadata =
68 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, source);
70 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
73 std::tie(linearizedInfo, linearizedIndices) =
75 rewriter, loc, typeBit, typeBit,
76 stridedMetadata.getConstifiedMixedOffset(),
77 stridedMetadata.getConstifiedMixedSizes(),
78 stridedMetadata.getConstifiedMixedStrides(),
81 return std::make_pair(
82 rewriter.
create<memref::ReinterpretCastOp>(
93 auto type = cast<MemRefType>(val.
getType());
94 return type.getRank() > 1;
98 auto type = cast<MemRefType>(val.
getType());
99 return type.getLayout().isIdentity() ||
100 isa<StridedLayoutAttr>(type.getLayout());
106 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107 memref::AllocOp>([](
auto op) {
return op.getMemref(); })
108 .
template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109 vector::MaskedStoreOp, vector::TransferReadOp,
110 vector::TransferWriteOp>(
111 [](
auto op) {
return op.getBase(); })
112 .Default([](
auto) {
return Value{}; });
115 template <
typename T>
116 static void castAllocResult(T oper, T newOper,
Location loc,
118 memref::ExtractStridedMetadataOp stridedMetadata =
119 rewriter.
create<memref::ExtractStridedMetadataOp>(loc, oper);
121 oper, cast<MemRefType>(oper.getType()), newOper,
123 stridedMetadata.getConstifiedMixedSizes(),
124 stridedMetadata.getConstifiedMixedStrides());
127 template <
typename T>
132 .template Case<memref::AllocOp>([&](
auto oper) {
133 auto newAlloc = rewriter.
create<memref::AllocOp>(
134 loc, cast<MemRefType>(flatMemref.
getType()),
135 oper.getAlignmentAttr());
136 castAllocResult(oper, newAlloc, loc, rewriter);
138 .
template Case<memref::AllocaOp>([&](
auto oper) {
139 auto newAlloca = rewriter.
create<memref::AllocaOp>(
140 loc, cast<MemRefType>(flatMemref.
getType()),
141 oper.getAlignmentAttr());
142 castAllocResult(oper, newAlloca, loc, rewriter);
144 .
template Case<memref::LoadOp>([&](
auto op) {
145 auto newLoad = rewriter.
create<memref::LoadOp>(
146 loc, op->getResultTypes(), flatMemref,
ValueRange{offset});
148 rewriter.
replaceOp(op, newLoad.getResult());
150 .
template Case<memref::StoreOp>([&](
auto op) {
151 auto newStore = rewriter.
create<memref::StoreOp>(
152 loc, op->getOperands().front(), flatMemref,
ValueRange{offset});
156 .
template Case<vector::LoadOp>([&](
auto op) {
157 auto newLoad = rewriter.
create<vector::LoadOp>(
158 loc, op->getResultTypes(), flatMemref,
ValueRange{offset});
160 rewriter.
replaceOp(op, newLoad.getResult());
162 .
template Case<vector::StoreOp>([&](
auto op) {
163 auto newStore = rewriter.
create<vector::StoreOp>(
164 loc, op->getOperands().front(), flatMemref,
ValueRange{offset});
168 .
template Case<vector::MaskedLoadOp>([&](
auto op) {
169 auto newMaskedLoad = rewriter.
create<vector::MaskedLoadOp>(
170 loc, op.getType(), flatMemref,
ValueRange{offset}, op.getMask(),
172 newMaskedLoad->
setAttrs(op->getAttrs());
173 rewriter.
replaceOp(op, newMaskedLoad.getResult());
175 .
template Case<vector::MaskedStoreOp>([&](
auto op) {
176 auto newMaskedStore = rewriter.
create<vector::MaskedStoreOp>(
177 loc, flatMemref,
ValueRange{offset}, op.getMask(),
178 op.getValueToStore());
179 newMaskedStore->
setAttrs(op->getAttrs());
182 .
template Case<vector::TransferReadOp>([&](
auto op) {
183 auto newTransferRead = rewriter.
create<vector::TransferReadOp>(
184 loc, op.getType(), flatMemref,
ValueRange{offset}, op.getPadding());
185 rewriter.
replaceOp(op, newTransferRead.getResult());
187 .
template Case<vector::TransferWriteOp>([&](
auto op) {
188 auto newTransferWrite = rewriter.
create<vector::TransferWriteOp>(
189 loc, op.getVector(), flatMemref,
ValueRange{offset});
190 rewriter.
replaceOp(op, newTransferWrite);
192 .Default([&](
auto op) {
193 op->emitOpError(
"unimplemented: do not know how to replace op.");
197 template <
typename T>
199 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200 std::is_same_v<T, memref::AllocOp>) {
203 return op.getIndices();
207 template <
typename T>
210 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
215 auto permutationMap = oper.getPermutationMap();
216 if (!permutationMap.isIdentity() &&
217 !permutationMap.isMinorIdentity()) {
218 return rewriter.notifyMatchFailure(
219 oper,
"only identity permutation map is supported");
221 mlir::ArrayAttr inbounds = oper.getInBounds();
222 if (llvm::any_of(inbounds, [](
Attribute attr) {
223 return !cast<BoolAttr>(attr).getValue();
226 "only inbounds are supported");
230 .Default([&](
auto op) {
return success(); });
233 template <
typename T>
236 LogicalResult matchAndRewrite(T op,
238 LogicalResult canFlatten = canBeFlattened(op, rewriter);
239 if (failed(canFlatten)) {
243 Value memref = getTargetMemref(op);
247 rewriter, op->getLoc(), memref, getIndices<T>(op));
248 replaceOp<T>(op, rewriter, flatMemref, offset);
253 struct FlattenMemrefsPass
254 :
public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
258 registry.
insert<affine::AffineDialect, arith::ArithDialect,
259 memref::MemRefDialect, vector::VectorDialect>();
262 void runOnOperation()
override {
268 return signalPassFailure();
275 patterns.insert<MemRefRewritePattern<memref::LoadOp>,
276 MemRefRewritePattern<memref::StoreOp>,
277 MemRefRewritePattern<memref::AllocOp>,
278 MemRefRewritePattern<memref::AllocaOp>,
279 MemRefRewritePattern<vector::LoadOp>,
280 MemRefRewritePattern<vector::StoreOp>,
281 MemRefRewritePattern<vector::TransferReadOp>,
282 MemRefRewritePattern<vector::TransferWriteOp>,
283 MemRefRewritePattern<vector::MaskedLoadOp>,
284 MemRefRewritePattern<vector::MaskedStoreOp>>(
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padded at the front to use for the linearized memref.
OpFoldResult linearizedOffset