#include <type_traits>

// ... (MLIR dialect and interface includes elided in this excerpt) ...

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"
/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Dimensions statically known to be in-bounds contribute no check.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Otherwise check that `index + vectorSize <= dim(source, indicesIdx)`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    // If both sides fold to constants, the check is decided statically.
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjoin the per-dimension checks into a single i1 condition.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
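// Illustrative sketch (not verbatim compiler output): for a transfer of
// vector<4xf32> from %A : memref<?xf32> at index %i, the condition built by
// createInBoundsCond is roughly:
//
//   %end  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %dim  = memref.dim %A, %c0 : memref<?xf32>
//   %cond = arith.cmpi sle, %end, %dim : index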
/// Precondition for splitting a vector.transfer into a full and a partial part.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // Reject 0-d transfers, non-minor-identity permutation maps, transfers with
  // no out-of-bounds dimension, and transfers already nested under an scf.if.
  if (xferOp.getTransferRank() == 0)
    return failure();
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
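// Illustrative example (not from the original file) of a transfer that passes
// this precondition: non-zero rank, minor identity map, and at least one
// dimension not marked in-bounds:
//
//   %v = vector.transfer_read %A[%i], %pad {in_bounds = [false]}
//        : memref<?xf32>, vector<4xf32>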
/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. Return a null MemRefType if they cannot be reconciled.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();
  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    // Keep a size/stride only where the two types agree; otherwise make it
    // dynamic.
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
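// Illustrative example (not from the original file): for
//   aT = memref<4x8xf32>                               (offset 0, strides [8, 1])
//   bT = memref<4x8xf32, strided<[8, 1], offset: 16>>
// the shapes and strides agree but the offsets differ, so the common type
// keeps the former and only makes the offset dynamic:
//   memref<4x8xf32, strided<[8, 1], offset: ?>>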
/// Casts the given memref to a compatible memref type. If the source memref
/// lives in a different memory space, insert a memref.memory_space_cast first.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}
/// Operates under a scoped context to build the intersection between the view
/// that `xferOp` accesses and the single-vector buffer `alloc`. Returns the
/// (copy source, copy destination) subview pair.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  auto leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // For each transferred dimension, the copy size is
    //   min(dim(source, indicesIdx) - index, dim(alloc, resultIdx)).
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    // ... `maps[0]` is the affine min map over (dimMemRef, index, dimAlloc) ...
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  // ... `srcIndices`, `destIndices` and `strides` (elided) are built from
  // `xferOp.getIndices()`, zeros and ones ...
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
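// Illustrative sketch (not verbatim output) of the intersection for a read of
// vector<4xf32> from %A : memref<?xf32> at index %i, with %alloc : memref<4xf32>:
//
//   %d   = memref.dim %A, %c0 : memref<?xf32>
//   %da  = memref.dim %alloc, %c0 : memref<4xf32>
//   %sz  = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d, %i, %da)
//   %src = memref.subview %A[%i] [%sz] [1]        // copy source
//   %dst = memref.subview %alloc[0] [%sz] [1]     // copy destination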
/// Given an xferOp for which the split preconditions hold, build the
/// full/partial scf.if for a transfer_read using the Linalg/memref-copy
/// strategy. Both branches yield a view of `compatibleMemRefType` plus the
/// indices at which the (now always in-bounds) read should happen.
static scf::IfOp createFullPartialLinalgCopy(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the original view and the original indices.
        // ...
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: fill `alloc` with the padding value, copy the valid
        // intersection into it, and yield `alloc` with all-zero indices.
        // ...
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        // ...
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
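// Illustrative sketch (not verbatim output) of the IR this produces for a 1-D
// transfer_read %A[%i] of vector<4xf32> with the Linalg-copy strategy:
//
//   %view, %idx = scf.if %inBounds -> (memref<?xf32>, index) {
//     scf.yield %A, %i : memref<?xf32>, index                // fast path
//   } else {
//     linalg.fill ins(%pad : f32) outs(%alloc : memref<4xf32>)
//     memref.copy %src_subview, %dst_subview                 // valid part only
//     %0 = memref.cast %alloc : memref<4xf32> to memref<?xf32>
//     scf.yield %0, %c0 : memref<?xf32>, index               // slow path
//   }
//   // The original transfer_read is then redirected to (%view, %idx) and
//   // marked in_bounds.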
/// Same as above, but the slow path materializes the full padded vector with a
/// cloned vector.transfer_read and spills it into `alloc` through a
/// vector.type_cast, instead of copying element-wise.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the original view and indices.
        // ...
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: clone the padding transfer_read, store its result into
        // `alloc` re-viewed as a memref of vector, and yield `alloc` with
        // all-zero indices.
        // ...
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));
        // ...
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
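// Illustrative contrast with the Linalg-copy strategy (sketch, not verbatim):
// the slow path becomes a single padded read that is spilled to `alloc`:
//
//   } else {
//     %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<4xf32>
//     %c = vector.type_cast %alloc : memref<4xf32> to memref<vector<4xf32>>
//     memref.store %v, %c[] : memref<vector<4xf32>>
//     %0 = memref.cast %alloc : memref<4xf32> to memref<?xf32>
//     scf.yield %0, %c0 : memref<?xf32>, index
//   }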
/// For a transfer_write, compute the location (view + indices) the full vector
/// should be written to: the original view when in-bounds, otherwise `alloc`.
static ValueRange getLocationToWriteFullVec(
    RewriterBase &b, vector::TransferWriteOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            // Fast path: write directly to the original view at the original
            // indices.
            // ...
            llvm::append_range(viewAndIndices, xferOp.getIndices());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            // Slow path: write the full vector into `alloc` at zero indices;
            // the valid part is copied back afterwards.
            // ...
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}
/// Slow-path epilogue for a transfer_write with the Linalg/memref-copy
/// strategy: when the write was out-of-bounds, copy the valid intersection
/// from `alloc` back into the original view.
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  // The copy-back only runs on the not-in-bounds path (guarded by an scf.if).
  // ...
  std::pair<Value, Value> copyArgs = createSubViewIntersection(
      rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
  b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
  // ...
}
/// Slow-path epilogue for a transfer_write with the vector-transfer strategy:
/// when the write was out-of-bounds, reload the full vector from `alloc` and
/// re-issue the original (masking) transfer_write with it.
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  // Guarded on `notInBounds`:
  // ...
  IRMapping mapping;
  Value load = b.create<memref::LoadOp>(
      loc,
      b.create<vector::TypeCastOp>(
          loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
      ValueRange());
  mapping.map(xferOp.getVector(), load);
  b.clone(*xferOp.getOperation(), mapping);
  // ...
}
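// Illustrative sketch (not verbatim output) of the rewritten write path for
// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<?xf32>:
//
//   %view, %idx = scf.if %inBounds -> (memref<?xf32>, index) {
//     scf.yield %A, %i : memref<?xf32>, index             // write in place
//   } else {
//     %0 = memref.cast %alloc : memref<4xf32> to memref<?xf32>
//     scf.yield %0, %c0 : memref<?xf32>, index            // write to the buffer
//   }
//   vector.transfer_write %v, %view[%idx] {in_bounds = [true]}
//   scf.if %notInBounds {
//     // Copy the valid intersection of %alloc back into %A (Linalg-copy
//     // strategy), or reload the vector from %alloc and re-issue the original
//     // masking transfer_write (vector-transfer strategy).
//   }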
/// Return the closest enclosing op with the AutomaticAllocationScope trait,
/// skipping over loops so the alloca is not re-created on every iteration.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // ... walk up the parent chain, remembering the last allocation scope seen,
  // and stop climbing at the first parent that is not a loop:
  if (!isa<scf::ForOp, affine::AffineForOp>(parent))
    break;
  // ...
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  // ... (`inBoundsAttr` is an all-true bool array over the transfer rank) ...
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    // Simply mark every dimension of the transfer in-bounds and return.
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");

  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
  auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
  if (!(xferReadOp || xferWriteOp))
    return failure();
  // Masked transfers are not supported.
  if (xferWriteOp && xferWriteOp.getMask())
    return failure();
  if (xferReadOp && xferReadOp.getMask())
    return failure();

  // Build the i1 condition that is true iff every dimension is in-bounds.
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  // ...

  // Alloca a single-vector buffer at the top of the closest allocation scope.
  // ...
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  auto shape = xferOp.getVectorType().getShape();
  Type elementType = xferOp.getVectorType().getElementType();
  // ... `alloc` = memref.alloca of MemRefType::get(shape, elementType),
  // inserted at the start of `scope` ...

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  // `returnTypes` is the compatible view type followed by one index type per
  // transferred dimension.
  // ...
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: build the scf.if yielding either the original view or the
    // filled-and-copied `alloc`, then redirect the existing transfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
    // The rewritten read always reads from a full view, so mark it in-bounds.
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Write case: write the full vector either to the original view or to
  // `alloc`, then copy the valid part back on the slow path.
  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp.getOperation(), mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);
  return success();
}
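// Illustrative usage sketch (not part of the original file): invoking the
// split directly on a matched transfer op. The wrapper name
// `splitTransferForExample` is hypothetical; the entry point and the options
// struct are the ones used above.
//
//   static LogicalResult splitTransferForExample(RewriterBase &rewriter,
//                                                VectorTransferOpInterface op) {
//     if (failed(splitFullAndPartialTransferPrecondition(op)))
//       return failure();
//     OpBuilder::InsertionGuard guard(rewriter);
//     rewriter.setInsertionPoint(op);
//     VectorTransformsOptions opts;
//     opts.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//     scf::IfOp ifOp;
//     return splitFullAndPartialTransfer(rewriter, op, opts, &ifOp);
//   }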
namespace {

/// Rewrite pattern wrapping splitFullAndPartialTransfer: it matches any op
/// implementing VectorTransferOpInterface that passes the split precondition
/// and an optional user-provided filter.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
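// Illustrative usage sketch (not part of the original file): registering the
// pattern through the declared populate function and applying it greedily,
// e.g. from a pass's runOnOperation(). `funcOp` is a hypothetical
// func::FuncOp; applyPatternsAndFoldGreedily comes from
// mlir/Transforms/GreedyPatternRewriteDriver.h.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   VectorTransformsOptions opts;
//   opts.vectorTransferSplit = VectorTransferSplit::LinalgCopy;
//   populateVectorTransferFullPartialPatterns(patterns, opts);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();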