18 #include "llvm/ADT/STLExtras.h"
24 if (!type.hasStaticShape())
34 int64_t runningStride = 1;
35 int64_t curDim = strides.size() - 1;
37 while (curDim >= 0 && strides[curDim] == runningStride) {
38 runningStride *= type.getDimSize(curDim);
43 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
55 unsigned sourceRank = sizes.size();
56 assert(sizes.size() == strides.size() &&
57 "expected as many sizes as strides for a memref");
61 assert(indicesVec.size() == strides.size() &&
62 "expected as many indices as rank of memref");
72 for (
unsigned i = 0; i < sourceRank; ++i) {
73 unsigned offsetIdx = 2 * i;
74 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
75 offsetValues[offsetIdx] = indicesVec[i];
76 offsetValues[offsetIdx + 1] = strides[i];
78 mulMap = mulMap * symbols[i];
82 int64_t scaler = dstBits / srcBits;
83 addMulMap = addMulMap.
floorDiv(scaler);
87 builder, loc, addMulMap, offsetValues);
95 builder, loc, s0.
floorDiv(scaler), {offset});
97 return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
105 if (!sizes.empty()) {
109 for (
int index = sizes.size() - 1; index > 0; --index) {
111 builder, loc, s0 * s1,
117 std::tie(linearizedMemRefInfo, std::ignore) =
120 return linearizedMemRefInfo;
128 std::vector<Operation *> opUses;
131 if (isa<memref::DeallocOp>(useOp) ||
133 !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
135 opUses.push_back(useOp);
140 uses.insert(uses.end(), opUses.begin(), opUses.end());
145 std::vector<Operation *> opToErase;
146 parentOp->
walk([&](memref::AllocOp op) {
147 std::vector<Operation *> candidates;
149 opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
150 opToErase.push_back(op.getOperation());
165 for (int64_t r = strides.size() - 1; r > 0; --r) {
167 builder, loc, s0 * s1, {strides[r], sizes[r]});
180 while (
auto op = source.getDefiningOp()) {
181 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(op);
182 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
185 source = cast<MemrefValue>(subViewOp.getSource());
186 }
else if (
auto castOp = dyn_cast<memref::CastOp>(op)) {
188 source = castOp.getSource();
197 while (
auto op = source.getDefiningOp()) {
198 if (
auto subView = dyn_cast<memref::SubViewOp>(op)) {
199 source = cast<MemrefValue>(subView.getSource());
200 }
else if (
auto cast = dyn_cast<memref::CastOp>(op)) {
201 source = cast.getSource();
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
MemrefValue skipSubViewsAndCasts(MemrefValue source)
Walk up the source chain until something an op other than a memref.subview or memref....
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
For a memref with offset, sizes and strides, returns the offset and size to use for the linearized me...