// Headers for the dialects and utilities used below.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;
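/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds: for each dimension that is not statically known to be
/// in-bounds, check `index + vector_size <= memref_size` and fold the
/// per-dimension checks together with `arith.andi`. Returns a null Value when
/// every check folds away statically.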
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Dimensions that are statically in-bounds do not contribute a check.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    // Skip the check when it is statically known to hold.
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
                              getValueOrCreateConstantIndexOp(b, loc, sum),
                              getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
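/// Precondition for splitting a transfer into a fast in-bounds path and a
/// slow out-of-bounds path: non-0-d transfer, minor identity permutation map,
/// at least one potentially out-of-bounds dimension, and not already nested
/// directly under an scf.if (which would re-trigger the pattern).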
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support the 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();
  // TODO: expand support beyond minor identity maps.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfers directly under an scf.if; this avoids applying the
  // pattern recursively.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
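/// Given two MemRefTypes `aT` and `bT`, return a MemRefType both can be cast
/// to: sizes and strides that agree are kept static, all others become
/// dynamic. Returns a null type when the ranks or stride lists differ.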
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
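/// Cast `memref` to `compatibleMemRefType`, inserting a
/// memref.memory_space_cast first when the memory spaces differ.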
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType,
                                            res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}
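/// Build the intersection of the view `xferOp` accesses at its indices with
/// the staging buffer `alloc`: each size is affine_min(dimMemRef - index,
/// dimAlloc). Returns the (copySrc, copyDest) subview pair for the partial
/// copy; the roles of the source view and `alloc` swap for writes.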
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition; it requires rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = memref::DimOp::create(b, xferOp.getLoc(),
                                            xferOp.getBase(), indicesIdx);
    Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // Compute affine_min(%dimMemRef - %index, %dimAlloc).
    Value affineMin = affine::AffineMinOp::create(
        b, loc, index.getType(), maps[0],
        ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::map_to_vector(
      xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = memref::SubViewOp::create(
      b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = memref::SubViewOp::create(
      b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
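/// Read case, linalg flavor: emit an scf.if that yields the original (cast)
/// memref and indices when `inBoundsCond` holds; otherwise fill `alloc` with
/// the padding value, copy the valid subview intersection into it, and yield
/// `alloc` with zero indices. The yielded values feed an in-bounds
/// vector.transfer_read.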
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the original (cast) memref and indices.
        Value res =
            castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: fill with padding, then copy the valid subview.
        linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
                               ValueRange{alloc});
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the original (cast) memref and indices.
        Value res =
            castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: clone the out-of-bounds read and spill the vector into
        // `alloc` through a 0-d vector.type_cast view.
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        memref::StoreOp::create(
            b, loc, vector,
            vector::TypeCastOp::create(
                b, loc, MemRefType::get({}, vector.getType()), alloc));
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
             b, loc, inBoundsCond,
             [&](OpBuilder &b, Location loc) {
               // Fast path: write the full vector to the original memref.
               Value res =
                   castToCompatibleMemRefType(b, memref, compatibleMemRefType);
               scf::ValueVector viewAndIndices{res};
               llvm::append_range(viewAndIndices, xferOp.getIndices());
               scf::YieldOp::create(b, loc, viewAndIndices);
             },
             [&](OpBuilder &b, Location loc) {
               // Slow path: write the full vector to the staging alloca.
               Value casted =
                   castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
               scf::ValueVector viewAndIndices{casted};
               viewAndIndices.insert(viewAndIndices.end(),
                                     xferOp.getTransferRank(), zero);
               scf::YieldOp::create(b, loc, viewAndIndices);
             })
      ->getResults();
}
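/// Write case, linalg flavor: when `inBoundsCond` does not hold, copy the
/// valid subview intersection from `alloc` back into the original memref.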
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond,
      arith::ConstantIntOp::create(b, loc, /*value=*/true, /*width=*/1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}
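/// Write case, vector.transfer flavor: when `inBoundsCond` does not hold,
/// reload the full vector from `alloc` and clone the original
/// (out-of-bounds) transfer_write to perform the copy-back.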
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond,
      arith::ConstantIntOp::create(b, loc, /*value=*/true, /*width=*/1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = memref::LoadOp::create(
        b, loc,
        vector::TypeCastOp::create(
            b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange{});
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}
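/// Return the closest enclosing AutomaticAllocationScope that is not itself
/// nested in a loop, so the staging alloca is hoisted out of scf.for and
/// affine.for bodies (allocas in loops are not freed until the loop exits).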
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Due to ordering, this is not a pattern
  // precondition.
  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");

  auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
  auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

  if (!(xferReadOp || xferWriteOp))
    return failure();
  // TODO: add support for masked transfers.
  if (xferWriteOp && xferWriteOp.getMask())
    return failure();
  if (xferReadOp && xferReadOp.getMask())
    return failure();

  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-scope `alloc` for transient storage of a single vector.
  Value alloc;
  {
    OpBuilder::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = memref::AllocaOp::create(b, scope->getLoc(),
                                     MemRefType::get(shape, elementType),
                                     ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (xferReadOp) {
    // Read case: full fill + partial copy -> in-bounds vector.transfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  // Write case: `xferWriteOp` is non-null here. Decide which location to
  // write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp.getOperation(), mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow-path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}
namespace {
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};
} // namespace
LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
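/// Registration hook for the pattern above; a minimal sketch matching the
/// public declaration of populateVectorTransferFullPartialPatterns (the
/// actual definition is elided from this listing).
void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}

// Usage sketch, assuming a greedy pattern driver and some enclosing `funcOp`
// (both hypothetical here, not part of this file):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   VectorTransformsOptions opts;
//   opts.vectorTransferSplit = VectorTransferSplit::VectorTransfer;
//   vector::populateVectorTransferFullPartialPatterns(patterns, opts);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));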