29 #include "llvm/ADT/TypeSwitch.h"
33 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
34 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
42 if (
Attribute offsetAttr = dyn_cast<Attribute>(in)) {
44 rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
46 return cast<Value>(in);
57 auto sourceType = cast<MemRefType>(source.
getType());
58 if (
failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
62 memref::ExtractStridedMetadataOp stridedMetadata =
63 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
65 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
68 std::tie(linearizedInfo, linearizedIndices) =
70 rewriter, loc, typeBit, typeBit,
71 stridedMetadata.getConstifiedMixedOffset(),
72 stridedMetadata.getConstifiedMixedSizes(),
73 stridedMetadata.getConstifiedMixedStrides(),
76 return std::make_pair(
77 memref::ReinterpretCastOp::create(
78 rewriter, loc, source,
88 auto type = cast<MemRefType>(val.
getType());
89 return type.getRank() > 1;
93 auto type = cast<MemRefType>(val.
getType());
94 return type.getLayout().isIdentity() ||
95 isa<StridedLayoutAttr>(type.getLayout());
101 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102 memref::AllocOp>([](
auto op) {
return op.getMemref(); })
103 .
template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104 vector::MaskedStoreOp, vector::TransferReadOp,
105 vector::TransferWriteOp>(
106 [](
auto op) {
return op.getBase(); })
107 .Default([](
auto) {
return Value{}; });
110 template <
typename T>
111 static void castAllocResult(T oper, T newOper,
Location loc,
113 memref::ExtractStridedMetadataOp stridedMetadata =
114 memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
116 oper, cast<MemRefType>(oper.getType()), newOper,
118 stridedMetadata.getConstifiedMixedSizes(),
119 stridedMetadata.getConstifiedMixedStrides());
122 template <
typename T>
127 .template Case<memref::AllocOp>([&](
auto oper) {
128 auto newAlloc = memref::AllocOp::create(
129 rewriter, loc, cast<MemRefType>(flatMemref.
getType()),
130 oper.getAlignmentAttr());
131 castAllocResult(oper, newAlloc, loc, rewriter);
133 .
template Case<memref::AllocaOp>([&](
auto oper) {
134 auto newAlloca = memref::AllocaOp::create(
135 rewriter, loc, cast<MemRefType>(flatMemref.
getType()),
136 oper.getAlignmentAttr());
137 castAllocResult(oper, newAlloca, loc, rewriter);
139 .
template Case<memref::LoadOp>([&](
auto op) {
141 memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
143 newLoad->setAttrs(op->getAttrs());
144 rewriter.
replaceOp(op, newLoad.getResult());
146 .
template Case<memref::StoreOp>([&](
auto op) {
148 memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
150 newStore->setAttrs(op->getAttrs());
153 .
template Case<vector::LoadOp>([&](
auto op) {
155 vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
157 newLoad->setAttrs(op->getAttrs());
158 rewriter.
replaceOp(op, newLoad.getResult());
160 .
template Case<vector::StoreOp>([&](
auto op) {
162 vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
164 newStore->setAttrs(op->getAttrs());
167 .
template Case<vector::MaskedLoadOp>([&](
auto op) {
168 auto newMaskedLoad = vector::MaskedLoadOp::create(
169 rewriter, loc, op.getType(), flatMemref,
ValueRange{offset},
170 op.getMask(), op.getPassThru());
171 newMaskedLoad->setAttrs(op->getAttrs());
172 rewriter.
replaceOp(op, newMaskedLoad.getResult());
174 .
template Case<vector::MaskedStoreOp>([&](
auto op) {
175 auto newMaskedStore = vector::MaskedStoreOp::create(
176 rewriter, loc, flatMemref,
ValueRange{offset}, op.getMask(),
177 op.getValueToStore());
178 newMaskedStore->setAttrs(op->getAttrs());
181 .
template Case<vector::TransferReadOp>([&](
auto op) {
182 auto newTransferRead = vector::TransferReadOp::create(
183 rewriter, loc, op.getType(), flatMemref,
ValueRange{offset},
185 rewriter.
replaceOp(op, newTransferRead.getResult());
187 .
template Case<vector::TransferWriteOp>([&](
auto op) {
188 auto newTransferWrite = vector::TransferWriteOp::create(
189 rewriter, loc, op.getVector(), flatMemref,
ValueRange{offset});
190 rewriter.
replaceOp(op, newTransferWrite);
192 .Default([&](
auto op) {
193 op->emitOpError(
"unimplemented: do not know how to replace op.");
197 template <
typename T>
199 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200 std::is_same_v<T, memref::AllocOp>) {
203 return op.getIndices();
207 template <
typename T>
210 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
215 auto permutationMap = oper.getPermutationMap();
216 if (!permutationMap.isIdentity() &&
217 !permutationMap.isMinorIdentity()) {
218 return rewriter.notifyMatchFailure(
219 oper,
"only identity permutation map is supported");
221 mlir::ArrayAttr inbounds = oper.getInBounds();
222 if (llvm::any_of(inbounds, [](
Attribute attr) {
223 return !cast<BoolAttr>(attr).getValue();
226 "only inbounds are supported");
230 .Default([&](
auto op) {
return success(); });
233 template <
typename T>
236 LogicalResult matchAndRewrite(T op,
238 LogicalResult canFlatten = canBeFlattened(op, rewriter);
243 Value memref = getTargetMemref(op);
247 rewriter, op->getLoc(), memref, getIndices<T>(op));
248 replaceOp<T>(op, rewriter, flatMemref, offset);
253 struct FlattenMemrefsPass
254 :
public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
258 registry.
insert<affine::AffineDialect, arith::ArithDialect,
259 memref::MemRefDialect, vector::VectorDialect>();
262 void runOnOperation()
override {
268 return signalPassFailure();
275 patterns.insert<MemRefRewritePattern<memref::LoadOp>,
276 MemRefRewritePattern<memref::StoreOp>,
277 MemRefRewritePattern<memref::AllocOp>,
278 MemRefRewritePattern<memref::AllocaOp>,
279 MemRefRewritePattern<vector::LoadOp>,
280 MemRefRewritePattern<vector::StoreOp>,
281 MemRefRewritePattern<vector::TransferReadOp>,
282 MemRefRewritePattern<vector::TransferWriteOp>,
283 MemRefRewritePattern<vector::MaskedLoadOp>,
284 MemRefRewritePattern<vector::MaskedStoreOp>>(
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces the original op).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padded at the front to use for the linearized memref.
OpFoldResult linearizedOffset