#include "llvm/ADT/TypeSwitch.h"

#include <cassert>
#include <cstdint>
#include <tuple>
#include <utility>

#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
42 if (
Attribute offsetAttr = dyn_cast<Attribute>(in)) {
44 rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
46 return cast<Value>(in);
57 auto sourceType = cast<MemRefType>(source.
getType());
58 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
62 memref::ExtractStridedMetadataOp stridedMetadata =
63 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
65 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
68 std::tie(linearizedInfo, linearizedIndices) =
70 rewriter, loc, typeBit, typeBit,
71 stridedMetadata.getConstifiedMixedOffset(),
72 stridedMetadata.getConstifiedMixedSizes(),
73 stridedMetadata.getConstifiedMixedStrides(),
76 return std::make_pair(
77 memref::ReinterpretCastOp::create(
78 rewriter, loc, source,
88 auto type = cast<MemRefType>(val.
getType());
89 return type.getRank() > 1;
93 auto type = cast<MemRefType>(val.
getType());
94 return type.getLayout().isIdentity() ||
95 isa<StridedLayoutAttr>(type.getLayout());
101 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102 memref::AllocOp>([](
auto op) {
return op.getMemref(); })
103 .
template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104 vector::MaskedStoreOp, vector::TransferReadOp,
105 vector::TransferWriteOp>(
106 [](
auto op) {
return op.getBase(); })
115 .Case([&](memref::LoadOp op) {
117 memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
119 newLoad->setAttrs(op->getAttrs());
120 rewriter.
replaceOp(op, newLoad.getResult());
122 .Case([&](memref::StoreOp op) {
124 memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
126 newStore->setAttrs(op->getAttrs());
129 .Case([&](vector::LoadOp op) {
131 vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
133 newLoad->setAttrs(op->getAttrs());
134 rewriter.
replaceOp(op, newLoad.getResult());
136 .Case([&](vector::StoreOp op) {
138 vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
140 newStore->setAttrs(op->getAttrs());
143 .Case([&](vector::MaskedLoadOp op) {
144 auto newMaskedLoad = vector::MaskedLoadOp::create(
145 rewriter, loc, op.getType(), flatMemref,
ValueRange{offset},
146 op.getMask(), op.getPassThru());
147 newMaskedLoad->setAttrs(op->getAttrs());
148 rewriter.
replaceOp(op, newMaskedLoad.getResult());
150 .Case([&](vector::MaskedStoreOp op) {
151 auto newMaskedStore = vector::MaskedStoreOp::create(
152 rewriter, loc, flatMemref,
ValueRange{offset}, op.getMask(),
153 op.getValueToStore());
154 newMaskedStore->setAttrs(op->getAttrs());
157 .Case([&](vector::TransferReadOp op) {
158 auto newTransferRead = vector::TransferReadOp::create(
159 rewriter, loc, op.getType(), flatMemref,
ValueRange{offset},
161 rewriter.
replaceOp(op, newTransferRead.getResult());
163 .Case([&](vector::TransferWriteOp op) {
164 auto newTransferWrite = vector::TransferWriteOp::create(
165 rewriter, loc, op.getVector(), flatMemref,
ValueRange{offset});
166 rewriter.
replaceOp(op, newTransferWrite);
168 .Default([&](
auto op) {
169 op->emitOpError(
"unimplemented: do not know how to replace op.");
175 return op.getIndices();
181 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
186 auto permutationMap = oper.getPermutationMap();
187 if (!permutationMap.isIdentity() &&
188 !permutationMap.isMinorIdentity()) {
190 oper,
"only identity permutation map is supported");
192 mlir::ArrayAttr inbounds = oper.getInBounds();
193 if (llvm::any_of(inbounds, [](
Attribute attr) {
194 return !cast<BoolAttr>(attr).getValue();
197 "only inbounds are supported");
201 .Default([&](
auto op) {
return success(); });
216template <
typename AllocLikeOp>
218 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
219 LogicalResult matchAndRewrite(AllocLikeOp op,
220 PatternRewriter &rewriter)
const override {
224 Location loc = op->getLoc();
225 auto memrefType = cast<MemRefType>(op.getType());
226 auto elemType = memrefType.getElementType();
227 if (!elemType.isIntOrFloat())
229 unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
231 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
233 int64_t staticOffset;
234 SmallVector<int64_t> staticStrides;
235 if (
failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
237 if (staticOffset == ShapedType::kDynamic)
239 SmallVector<OpFoldResult> strides;
240 strides.reserve(staticStrides.size());
241 for (int64_t stride : staticStrides) {
242 if (stride == ShapedType::kDynamic)
244 "dynamic stride cannot be computed");
250 memref::LinearizedMemRefInfo linearizedInfo;
251 OpFoldResult linearizedOffset;
252 std::tie(linearizedInfo, linearizedOffset) =
254 rewriter, loc, elemBitWidth, elemBitWidth, rewriter.
getIndexAttr(0),
256 (void)linearizedOffset;
263 if (staticOffset != 0) {
267 rewriter, loc, s0 + staticOffset, {flatSizeOfr});
272 int64_t flatDimSize = ShapedType::kDynamic;
273 if (
auto attr = dyn_cast<Attribute>(flatSizeOfr))
274 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
275 flatDimSize = intAttr.getInt();
277 auto flatMemrefType =
278 MemRefType::get({flatDimSize}, memrefType.getElementType(),
279 StridedLayoutAttr::get(rewriter.
getContext(), 0, {1}),
280 memrefType.getMemorySpace());
283 SmallVector<Value, 1> dynSizes;
284 if (flatDimSize == ShapedType::kDynamic)
287 auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
288 op.getAlignmentAttr());
290 op, cast<MemRefType>(op.getType()), newOp,
298 using OpRewritePattern<T>::OpRewritePattern;
299 LogicalResult matchAndRewrite(T op,
300 PatternRewriter &rewriter)
const override {
301 LogicalResult canFlatten = canBeFlattened(op, rewriter);
305 Value memref = getTargetMemref(op);
310 rewriter, op->getLoc(), memref, getIndices<T>(op));
311 replaceOp<T>(op, rewriter, flatMemref, offset);
316struct FlattenMemrefsPass
320 void getDependentDialects(DialectRegistry ®istry)
const override {
321 registry.
insert<affine::AffineDialect, arith::ArithDialect,
322 memref::MemRefDialect, vector::VectorDialect>();
325 void runOnOperation()
override {
331 return signalPassFailure();
339 patterns.
insert<MemRefRewritePattern<vector::LoadOp>,
340 MemRefRewritePattern<vector::StoreOp>,
341 MemRefRewritePattern<vector::TransferReadOp>,
342 MemRefRewritePattern<vector::TransferWriteOp>,
343 MemRefRewritePattern<vector::MaskedLoadOp>,
344 MemRefRewritePattern<vector::MaskedStoreOp>>(
349 patterns.
insert<MemRefRewritePattern<memref::LoadOp>,
350 MemRefRewritePattern<memref::StoreOp>,
351 AllocLikeFlattenPattern<memref::AllocOp>,
352 AllocLikeFlattenPattern<memref::AllocaOp>>(
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns)
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns)
Patterns for flattening multi-dimensional memref operations into one-dimensional memref operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixed point is reached.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padded at the front to use for the linearized memref.
OpFoldResult linearizedSize
OpFoldResult linearizedOffset