27 #include "llvm/ADT/STLExtras.h"
29 #define DEBUG_TYPE "vector-transfer-split"
37 VectorTransferOpInterface xferOp) {
38 assert(xferOp.getPermutationMap().isMinorIdentity() &&
39 "Expected minor identity map");
41 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
45 if (xferOp.isDimInBounds(resultIdx))
49 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
52 {xferOp.getIndices()[indicesIdx]});
57 if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
60 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
65 inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
108 if (xferOp.getTransferRank() == 0)
112 if (!xferOp.getPermutationMap().isMinorIdentity())
115 if (!xferOp.hasOutOfBoundsDim())
120 if (isa<scf::IfOp>(xferOp->getParentOp()))
135 if (memref::CastOp::areCastCompatible(aT, bT))
137 if (aT.getRank() != bT.getRank())
139 int64_t aOffset, bOffset;
141 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
142 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
143 aStrides.size() != bStrides.size())
149 resStrides(bT.getRank(), 0);
150 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
152 (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
154 (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
156 resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
158 resShape, aT.getElementType(),
166 MemRefType compatibleMemRefType) {
167 MemRefType sourceType = cast<MemRefType>(memref.
getType());
169 if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
171 sourceType.getShape(), sourceType.getElementType(),
172 sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
174 memref::MemorySpaceCastOp::create(b, memref.
getLoc(), sourceType, res);
176 if (sourceType == compatibleMemRefType)
178 return memref::CastOp::create(b, memref.
getLoc(), compatibleMemRefType, res);
184 static std::pair<Value, Value>
188 int64_t memrefRank = xferOp.getShapedType().getRank();
190 assert(memrefRank == cast<MemRefType>(alloc.
getType()).getRank() &&
191 "Expected memref rank to match the alloc rank");
193 xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
195 sizes.append(leadingIndices.begin(), leadingIndices.end());
196 auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
197 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
200 memref::DimOp::create(b, xferOp.getLoc(), xferOp.getBase(), indicesIdx);
201 Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
202 Value index = xferOp.getIndices()[indicesIdx];
209 affine::AffineMinOp::create(b, loc, index.
getType(), maps[0],
211 sizes.push_back(affineMin);
218 auto copySrc = memref::SubViewOp::create(
219 b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
220 auto copyDest = memref::SubViewOp::create(
221 b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
222 return std::make_pair(copySrc, copyDest);
247 MemRefType compatibleMemRefType,
Value alloc) {
250 Value memref = xferOp.getBase();
251 return scf::IfOp::create(
252 b, loc, inBoundsCond,
256 llvm::append_range(viewAndIndices, xferOp.getIndices());
257 scf::YieldOp::create(b, loc, viewAndIndices);
260 linalg::FillOp::create(b, loc,
ValueRange{xferOp.getPadding()},
266 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
268 memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
272 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
274 scf::YieldOp::create(b, loc, viewAndIndices);
299 Value inBoundsCond, MemRefType compatibleMemRefType,
Value alloc) {
301 scf::IfOp fullPartialIfOp;
303 Value memref = xferOp.getBase();
304 return scf::IfOp::create(
305 b, loc, inBoundsCond,
309 llvm::append_range(viewAndIndices, xferOp.getIndices());
310 scf::YieldOp::create(b, loc, viewAndIndices);
314 Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
315 memref::StoreOp::create(
317 vector::TypeCastOp::create(
323 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
325 scf::YieldOp::create(b, loc, viewAndIndices);
347 MemRefType compatibleMemRefType,
Value alloc) {
350 Value memref = xferOp.getBase();
351 return scf::IfOp::create(
352 b, loc, inBoundsCond,
357 llvm::append_range(viewAndIndices, xferOp.getIndices());
358 scf::YieldOp::create(b, loc, viewAndIndices);
364 viewAndIndices.insert(viewAndIndices.end(),
365 xferOp.getTransferRank(), zero);
366 scf::YieldOp::create(b, loc, viewAndIndices);
385 vector::TransferWriteOp xferOp,
388 auto notInBounds = arith::XOrIOp::create(
393 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
395 memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
413 vector::TransferWriteOp xferOp,
417 auto notInBounds = arith::XOrIOp::create(
421 Value load = memref::LoadOp::create(
423 vector::TypeCastOp::create(
426 mapping.
map(xferOp.getVector(), load);
427 b.
clone(*xferOp.getOperation(), mapping);
442 if (!isa<scf::ForOp, affine::AffineForOp>(parent))
445 assert(scope &&
"Expected op to be inside automatic allocation scope");
517 if (
options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
519 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
528 "Expected splitFullAndPartialTransferPrecondition to hold");
530 auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
531 auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
533 if (!(xferReadOp || xferWriteOp))
535 if (xferWriteOp && xferWriteOp.getMask())
537 if (xferReadOp && xferReadOp.getMask())
544 b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
554 "AutomaticAllocationScope with >1 regions");
556 auto shape = xferOp.getVectorType().getShape();
557 Type elementType = xferOp.getVectorType().getElementType();
558 alloc = memref::AllocaOp::create(b, scope->
getLoc(),
563 MemRefType compatibleMemRefType =
565 cast<MemRefType>(alloc.
getType()));
566 if (!compatibleMemRefType)
571 returnTypes[0] = compatibleMemRefType;
573 if (
auto xferReadOp =
574 dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
576 scf::IfOp fullPartialIfOp =
577 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
580 compatibleMemRefType, alloc)
582 inBoundsCond, compatibleMemRefType,
585 *ifOp = fullPartialIfOp;
588 for (
unsigned i = 0, e = returnTypes.size(); i != e; ++i)
589 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
592 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
598 auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
602 b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
608 mapping.
map(xferWriteOp.getBase(), memrefAndIndices.front());
609 mapping.
map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
611 clone->
setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
615 if (
options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
628 struct VectorTransferFullPartialRewriter :
public RewritePattern {
629 using FilterConstraintType =
630 std::function<LogicalResult(VectorTransferOpInterface op)>;
632 explicit VectorTransferFullPartialRewriter(
635 FilterConstraintType filter =
636 [](VectorTransferOpInterface op) {
return success(); },
639 filter(std::move(filter)) {}
642 LogicalResult matchAndRewrite(
Operation *op,
647 FilterConstraintType filter;
652 LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
654 auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
656 failed(filter(xferOp)))
static llvm::ManagedStatic< PassManagerOptions > options
static scf::IfOp createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static void createFullPartialVectorTransferWrite(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc)
Given an xferOp for which:
static ValueRange getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static scf::IfOp createFullPartialVectorTransferRead(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fast path and a ...
static std::pair< Value, Value > createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Value alloc)
Operates under a scoped context to build the intersection between the view xferOp....
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT)
Given two MemRefTypes aT and bT, return a MemRefType to which both can be cast.
static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)
Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.
static Operation * getAutomaticAllocationScope(Operation *op)
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType)
Casts the given memref to a compatible memref type.
Base type for affine expression.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
LogicalResult splitFullAndPartialTransfer(RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options=VectorTransformsOptions(), scf::IfOp *ifOp=nullptr)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a s...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.