#include "llvm/ADT/STLExtras.h"

// Excerpt from isStaticShapeAndContiguousRowMajor(MemRefType type): scan from
// the innermost dimension outward; every dimension must either carry the
// expected running stride or have size 1.
  if (!type.hasStaticShape())
    return false;
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }
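As a usage sketch (not part of the file): the check accepts identity-layout memrefs and rejects padded strided layouts. The function name contiguityExamples is illustrative, and the helper is assumed to live in the mlir::memref namespace as declared in MemRefUtils.h.

#include <cassert>
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

static void contiguityExamples(mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);
  mlir::Type f32 = b.getF32Type();
  // memref<2x3x4xf32> with the identity layout: strides are [12, 4, 1].
  auto contiguous = mlir::MemRefType::get({2, 3, 4}, f32);
  // memref<2x3x4xf32, strided<[24, 8, 1]>>: rows are padded, not contiguous.
  auto padded = mlir::MemRefType::get(
      {2, 3, 4}, f32,
      mlir::StridedLayoutAttr::get(&ctx, /*offset=*/0, /*strides=*/{24, 8, 1}));
  assert(mlir::memref::isStaticShapeAndContiguousRowMajor(contiguous));
  assert(!mlir::memref::isStaticShapeAndContiguousRowMajor(padded));
}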
// Excerpt from getLinearizedMemRefOffsetAndSize(): linearize a multi-
// dimensional access (indices and strides) into a single index, rescaled by
// dstBits / srcBits for narrow-type emulation.
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  ... // construction of indicesVec from `indices` (elided)
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");
  ... // affine symbols and offsetValues are set up here (elided)
  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];

    mulMap = mulMap * symbols[i];
  }

  // Rescale by the emulation factor dstBits / srcBits.
  int64_t scaler = dstBits / srcBits;
  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
  ... // linearizedSize is computed from mulMap over the sizes (elided)
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
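The arithmetic is easier to see on concrete numbers. Below is a hedged, plain-integer rendition (not from the file) of the same linearization for 4-bit elements packed into 8-bit storage (scaler = 2) on a 3x8 buffer with row-major strides [8, 1]; the struct and function names are illustrative only.

#include <cstdint>

struct LinearizedAccess {
  int64_t baseOffset;      // original offset, rescaled to 8-bit words
  int64_t linearizedIndex; // index into the flat i8 buffer
  int64_t intraWordOffset; // which i4 element inside that i8 word
};

static LinearizedAccess linearizeI4Access(int64_t offset, int64_t i, int64_t j) {
  const int64_t strides[2] = {8, 1};
  const int64_t scaler = 8 / 4; // dstBits / srcBits
  int64_t addMul = i * strides[0] + j * strides[1];
  return {offset / scaler, addMul / scaler, addMul % scaler};
}

// Example: linearizeI4Access(0, 1, 3) yields {0, 5, 1}; element (1, 3) is i4
// element 11 (0-based), i.e. the second nibble of byte 5 in the packed buffer.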
// Excerpt from the sizes-only overload of getLinearizedMemRefOffsetAndSize():
// materialize row-major strides from the sizes, then reuse the overload above
// and discard the linearized indices.
  if (!sizes.empty()) {
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1, {strides[index], sizes[index]});
    }
  }
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
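A usage sketch (assumed setup, not from the file): computing the flat allocation size when rewriting a memref<3x8xi4> into i8-backed storage. The function name linearizeAllocSize is illustrative, and the helper and LinearizedMemRefInfo are assumed to be declared under mlir::memref in MemRefUtils.h.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static memref::LinearizedMemRefInfo
linearizeAllocSize(OpBuilder &builder, Location loc) {
  SmallVector<OpFoldResult> sizes = {builder.getIndexAttr(3),
                                     builder.getIndexAttr(8)};
  // 4-bit elements packed into 8-bit words: 3 * 8 / 2 = 12 i8 elements.
  return memref::getLinearizedMemRefOffsetAndSize(
      builder, loc, /*srcBits=*/4, /*dstBits=*/8,
      /*offset=*/builder.getIndexAttr(0), sizes);
}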
// Excerpt from resultIsNotRead(Operation *op, std::vector<Operation *> &uses):
// a use is harmless if it deallocates the buffer or never reads from it.
  std::vector<Operation *> opUses;
  ...
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        ...) {
      opUses.push_back(useOp);
  ...
  uses.insert(uses.end(), opUses.begin(), opUses.end());
  return true;
// Excerpt from eraseDeadAllocAndStores(RewriterBase &rewriter, Operation
// *parentOp): erase every alloc that is never read, plus its dead users.
  std::vector<Operation *> opToErase;
  parentOp->walk([&](memref::AllocOp op) {
    std::vector<Operation *> candidates;
    if (resultIsNotRead(op, candidates)) {
      opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
      opToErase.push_back(op.getOperation());
    }
  });
  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
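A minimal sketch (assumed pass boilerplate, not from the file) of running this cleanup over a function body; the function name cleanupDeadBuffers and the IR in the comment are illustrative.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void cleanupDeadBuffers(func::FuncOp func) {
  IRRewriter rewriter(func.getContext());
  // Before: %buf = memref.alloc() : memref<16xf32>
  //         memref.store %v, %buf[%i] : memref<16xf32>
  //         memref.dealloc %buf : memref<16xf32>
  // After:  all three ops are erased, because %buf is never read.
  memref::eraseDeadAllocAndStores(rewriter, func.getOperation());
}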
// Excerpt from computeSuffixProductIRBlockImpl(): suffix product of the sizes.
  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
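For static sizes the suffix product is just the usual row-major stride computation; here is a hedged plain-integer analogue (not from the file), with an illustrative name.

#include <cstdint>
#include <vector>

static std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1); // the "unit" seed
  for (int64_t r = static_cast<int64_t>(strides.size()) - 1; r > 0; --r)
    strides[r - 1] = strides[r] * sizes[r];
  return strides;
}

// suffixProduct({2, 3, 4}) == {12, 4, 1}: exactly the strides of a contiguous
// row-major memref<2x3x4xf32>.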
// Excerpt from skipFullyAliasingOperations(MemrefValue source): walk up the
// def chain past ops that alias the entire buffer (zero-offset, unit-stride
// subviews and casts); stop at anything that changes the view of memory.
  while (auto op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A subview with zero offsets and unit strides covers the whole buffer.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      source = castOp.getSource();
    } else {
      return source;
    }
  }
  return source;
// Excerpt from skipViewLikeOps(MemrefValue source): walk up the def chain
// through any ViewLikeOpInterface op, following its view source.
  while (auto op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      source = cast<MemrefValue>(viewLike.getViewSource());
      continue;
    }
    return source;
  }
  return source;
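A minimal sketch (assumed setup, not from the file) contrasting the two helpers; the function name inspectSource is illustrative, MemrefValue is the TypedValue<BaseMemRefType> alias from MemRefUtils.h, and both helpers are assumed to live under mlir::memref.

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static void inspectSource(TypedValue<BaseMemRefType> v) {
  // Stops at the first op that exposes only part of the buffer (e.g. a
  // subview with a nonzero offset); the result aliases exactly the same
  // elements as `v`.
  auto sameBuffer = memref::skipFullyAliasingOperations(v);
  // Walks through every view-like op (subview, cast, reinterpret_cast, ...)
  // back to a value that is not itself a view, typically the allocation.
  auto underlyingAlloc = memref::skipViewLikeOps(v);
  (void)sameBuffer;
  (void)underlyingAlloc;
}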
Referenced and defined symbols (from the cross-linked declarations):

OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands)
    Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...

static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses)
    Returns true if all the uses of op are not read/load.

MemrefValue skipFullyAliasingOperations(MemrefValue source)
    Walk up the source chain until an operation that changes/defines the view of memory is found (i....

void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
    Track temporary allocations that are never read from.

std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {})

MemrefValue skipViewLikeOps(MemrefValue source)
    Walk up the source chain until we find an operation that is not a view of the source memref (i....

bool isStaticShapeAndContiguousRowMajor(MemRefType type)
    Returns true if the memref type has static shapes and represents a contiguous chunk of memory.

static SmallVector<OpFoldResult> computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef<OpFoldResult> sizes, OpFoldResult unit)

SmallVector<OpFoldResult> computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef<OpFoldResult> sizes)
    Given a set of sizes, return the suffix product.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl<int64_t> &strides, int64_t &offset)
    Returns the strides of the MemRef if the layout map is in strided form.

TypedValue<BaseMemRefType> MemrefValue
    A value with a memref type.

void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
    Bind a list of AffineExpr references to SymbolExpr at positions: [0 .

void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs)

LinearizedMemRefInfo
    For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...