15 #include <type_traits>
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
36 #define DEBUG_TYPE "vector-transfer-split"
44 VectorTransferOpInterface xferOp) {
45 assert(xferOp.getPermutationMap().isMinorIdentity() &&
46 "Expected minor identity map");
48 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
52 if (xferOp.isDimInBounds(resultIdx))
56 int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
59 {xferOp.getIndices()[indicesIdx]});
64 if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
67 b.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
72 inBoundsCond = b.
create<arith::AndIOp>(loc, inBoundsCond, cond);
115 if (xferOp.getTransferRank() == 0)
119 if (!xferOp.getPermutationMap().isMinorIdentity())
122 if (!xferOp.hasOutOfBoundsDim())
127 if (isa<scf::IfOp>(xferOp->getParentOp()))
142 if (memref::CastOp::areCastCompatible(aT, bT))
144 if (aT.getRank() != bT.getRank())
146 int64_t aOffset, bOffset;
150 aStrides.size() != bStrides.size())
156 resStrides(bT.getRank(), 0);
157 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
159 (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
161 (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
163 resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
165 resShape, aT.getElementType(),
173 MemRefType compatibleMemRefType) {
174 MemRefType sourceType = cast<MemRefType>(memref.
getType());
176 if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
178 sourceType.getShape(), sourceType.getElementType(),
179 sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
180 res = b.
create<memref::MemorySpaceCastOp>(memref.
getLoc(), sourceType, res);
182 if (sourceType == compatibleMemRefType)
184 return b.
create<memref::CastOp>(memref.
getLoc(), compatibleMemRefType, res);
190 static std::pair<Value, Value>
194 int64_t memrefRank = xferOp.getShapedType().getRank();
196 assert(memrefRank == cast<MemRefType>(alloc.
getType()).getRank() &&
197 "Expected memref rank to match the alloc rank");
199 xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
201 sizes.append(leadingIndices.begin(), leadingIndices.end());
202 auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
203 xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
205 Value dimMemRef = b.
create<memref::DimOp>(xferOp.getLoc(),
206 xferOp.getSource(), indicesIdx);
207 Value dimAlloc = b.
create<memref::DimOp>(loc, alloc, resultIdx);
208 Value index = xferOp.getIndices()[indicesIdx];
215 loc, index.getType(), maps[0],
ValueRange{dimMemRef, index, dimAlloc});
216 sizes.push_back(affineMin);
223 auto copySrc = b.
create<memref::SubViewOp>(
224 loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
225 auto copyDest = b.
create<memref::SubViewOp>(
226 loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
227 return std::make_pair(copySrc, copyDest);
252 MemRefType compatibleMemRefType,
Value alloc) {
254 Value zero = b.
create<arith::ConstantIndexOp>(loc, 0);
255 Value memref = xferOp.getSource();
256 return b.
create<scf::IfOp>(
261 viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
262 xferOp.getIndices().end());
263 b.
create<scf::YieldOp>(loc, viewAndIndices);
272 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
274 b.
create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
278 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
280 b.
create<scf::YieldOp>(loc, viewAndIndices);
305 Value inBoundsCond, MemRefType compatibleMemRefType,
Value alloc) {
307 scf::IfOp fullPartialIfOp;
308 Value zero = b.
create<arith::ConstantIndexOp>(loc, 0);
309 Value memref = xferOp.getSource();
310 return b.
create<scf::IfOp>(
315 viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
316 xferOp.getIndices().end());
317 b.
create<scf::YieldOp>(loc, viewAndIndices);
321 Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
322 b.
create<memref::StoreOp>(
324 b.
create<vector::TypeCastOp>(
330 viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
332 b.
create<scf::YieldOp>(loc, viewAndIndices);
354 MemRefType compatibleMemRefType,
Value alloc) {
356 Value zero = b.
create<arith::ConstantIndexOp>(loc, 0);
357 Value memref = xferOp.getSource();
365 viewAndIndices.insert(viewAndIndices.end(),
366 xferOp.getIndices().begin(),
367 xferOp.getIndices().end());
368 b.
create<scf::YieldOp>(loc, viewAndIndices);
374 viewAndIndices.insert(viewAndIndices.end(),
375 xferOp.getTransferRank(), zero);
376 b.
create<scf::YieldOp>(loc, viewAndIndices);
395 vector::TransferWriteOp xferOp,
398 auto notInBounds = b.
create<arith::XOrIOp>(
399 loc, inBoundsCond, b.
create<arith::ConstantIntOp>(loc,
true, 1));
403 rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
405 b.
create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
423 vector::TransferWriteOp xferOp,
427 auto notInBounds = b.
create<arith::XOrIOp>(
428 loc, inBoundsCond, b.
create<arith::ConstantIntOp>(loc,
true, 1));
433 b.
create<vector::TypeCastOp>(
436 mapping.
map(xferOp.getVector(), load);
437 b.
clone(*xferOp.getOperation(), mapping);
452 if (!isa<scf::ForOp, affine::AffineForOp>(parent))
455 assert(scope &&
"Expected op to be inside automatic allocation scope");
527 if (
options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
529 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
538 "Expected splitFullAndPartialTransferPrecondition to hold");
540 auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
541 auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
543 if (!(xferReadOp || xferWriteOp))
545 if (xferWriteOp && xferWriteOp.getMask())
547 if (xferReadOp && xferReadOp.getMask())
554 b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
564 "AutomaticAllocationScope with >1 regions");
566 auto shape = xferOp.getVectorType().getShape();
567 Type elementType = xferOp.getVectorType().getElementType();
573 MemRefType compatibleMemRefType =
575 cast<MemRefType>(alloc.
getType()));
576 if (!compatibleMemRefType)
581 returnTypes[0] = compatibleMemRefType;
583 if (
auto xferReadOp =
584 dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
586 scf::IfOp fullPartialIfOp =
587 options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
590 compatibleMemRefType, alloc)
592 inBoundsCond, compatibleMemRefType,
595 *ifOp = fullPartialIfOp;
598 for (
unsigned i = 0, e = returnTypes.size(); i != e; ++i)
599 xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
602 xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
608 auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
612 b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
618 mapping.
map(xferWriteOp.getSource(), memrefAndIndices.front());
619 mapping.
map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
621 clone->
setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
625 if (
options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
638 struct VectorTransferFullPartialRewriter :
public RewritePattern {
639 using FilterConstraintType =
640 std::function<LogicalResult(VectorTransferOpInterface op)>;
642 explicit VectorTransferFullPartialRewriter(
645 FilterConstraintType filter =
646 [](VectorTransferOpInterface op) {
return success(); },
649 filter(std::move(filter)) {}
652 LogicalResult matchAndRewrite(
Operation *op,
657 FilterConstraintType filter;
662 LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
664 auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
666 failed(filter(xferOp)))
static llvm::ManagedStatic< PassManagerOptions > options
static scf::IfOp createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static void createFullPartialVectorTransferWrite(RewriterBase &b, vector::TransferWriteOp xferOp, Value inBoundsCond, Value alloc)
Given an xferOp for which:
static ValueRange getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static scf::IfOp createFullPartialVectorTransferRead(RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc)
Given an xferOp for which:
static LogicalResult splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fast path and a ...
static std::pair< Value, Value > createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Value alloc)
Operates under a scoped context to build the intersection between the view xferOp....
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT)
Given two MemRefTypes aT and bT, return a MemRefType to which both can be cast.
static Value createInBoundsCond(RewriterBase &b, VectorTransferOpInterface xferOp)
Build the condition to ensure that a particular VectorTransferOpInterface is in-bounds.
static Operation * getAutomaticAllocationScope(Operation *op)
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType)
Casts the given memref to a compatible memref type.
Base type for affine expression.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
LogicalResult splitFullAndPartialTransfer(RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options=VectorTransformsOptions(), scf::IfOp *ifOp=nullptr)
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a s...
void populateVectorTransferFullPartialPatterns(RewritePatternSet &patterns, const VectorTransformsOptions &options)
Populate patterns with the following patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.