#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Dimensions known in-bounds do not participate in the condition.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check `index + vector_size <= memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    // Skip the runtime check when it is statically known to hold.
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle,
                              getValueOrCreateConstantIndexOp(b, loc, sum),
                              getValueOrCreateConstantIndexOp(b, loc, dimSz));
    if (inBoundsCond)
      inBoundsCond = arith::AndIOp::create(b, loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
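
// A sketch of the check built per out-of-bounds dimension d, for a vector
// size of 4 (SSA names are illustrative, and the affine.apply may be
// constant-folded away entirely):
//
//   %end  = affine.apply affine_map<(d0) -> (d0 + 4)>(%index_d)
//   %dim  = memref.dim %src, %c_d : memref<...>
//   %cond = arith.cmpi sle, %end, %dim : index
//
// The per-dimension conditions are combined with arith.andi.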

static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support the 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();
  // TODO: expand support beyond minor identity permutation maps.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfers directly under an scf.if; this avoids applying the
  // pattern recursively to its own output.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
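
// For example, a transfer this precondition accepts (a sketch; shapes and
// names are illustrative):
//
//   %r = vector.transfer_read %A[%i, %j], %pad {in_bounds = [false, false]}
//       : memref<?x8xf32>, vector<4x4xf32>
//
// A 0-d transfer, one with a permuting map, or one already fully in-bounds
// is rejected.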

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If they cannot be joined, return a null MemRefType.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    // Keep the static size/stride only where `aT` and `bT` agree.
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
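
// A worked example (types are illustrative): joining
//   aT = memref<4x8xf32, strided<[8, 1]>>  and  bT = memref<4x4xf32>
// keeps the matching extents/strides and relaxes the rest, giving
//   memref<4x?xf32, strided<[?, 1], offset: 0>>.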

/// Casts the given memref to a type compatible with `compatibleMemRefType`.
/// If the memory spaces differ, a memref.memory_space_cast is inserted first,
/// followed by a memref.cast for any remaining shape/layout change.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = memref::MemorySpaceCastOp::create(b, memref.getLoc(), sourceType,
                                            res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return memref::CastOp::create(b, memref.getLoc(), compatibleMemRefType, res);
}
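
// For example (illustrative): if `memref` is memref<4x4xf32, 1> and the
// compatible type is memref<?x?xf32, strided<[?, ?], offset: ?>> in the
// default space, this emits a memref.memory_space_cast followed by a
// memref.cast.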

/// Operates under a scoped context to build the intersection between the view
/// `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
/// Returns the (copySrc, copyDest) subview pair over that intersection.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition; it requires rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = memref::DimOp::create(b, xferOp.getLoc(),
                                            xferOp.getBase(), indicesIdx);
    Value dimAlloc = memref::DimOp::create(b, loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = affine::AffineMinOp::create(
        b, loc, index.getType(), maps[0],
        ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  // The original view is indexed at the transfer's indices; the alloc side is
  // always indexed at 0.
  SmallVector<OpFoldResult> srcIndices(memrefRank, b.getIndexAttr(0)),
      destIndices(memrefRank, b.getIndexAttr(0)),
      strides(memrefRank, b.getIndexAttr(1));
  if (isaWrite)
    destIndices = getAsOpFoldResult(xferOp.getIndices());
  else
    srcIndices = getAsOpFoldResult(xferOp.getIndices());
  auto copySrc = memref::SubViewOp::create(
      b, loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = memref::SubViewOp::create(
      b, loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
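
// In other words, for each transferred dimension d the intersection size is
//   size_d = min(dimMemRef_d - index_d, dimAlloc_d),
// i.e. what remains of the source past the transfer offset, clamped to the
// staging buffer extent. A sketch of the emitted IR for one dimension
// (names illustrative):
//
//   %dimM = memref.dim %src, %c_d : memref<...>
//   %dimA = memref.dim %alloc, %c_d : memref<...>
//   %sz_d = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
//             (%dimM, %index_d, %dimA)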

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce an scf.if that yields either a view of the original memref (fast
/// path) or a view of `alloc` filled with the padding value plus a partial
/// copy of the data (slow path). Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res =
            castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        linalg::FillOp::create(b, loc, ValueRange{xferOp.getPadding()},
                               ValueRange{alloc});
        // Copy the overlapping data into the staging buffer.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}
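
// The produced IR resembles (a sketch for a 2-D transfer_read; %A, %pad, and
// the index values are illustrative):
//
//   %1:3 = scf.if (%inBounds) {
//     %view = memref.cast %A : memref<A...> to compatibleMemRefType
//     scf.yield %view, %i, %j : compatibleMemRefType, index, index
//   } else {
//     linalg.fill ins(%pad) outs(%alloc)
//     %3 = memref.subview %A [...] [...] [...]
//     %4 = memref.subview %alloc [0, 0] [...] [...]
//     memref.copy %3, %4
//     %5 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %5, %c0, %c0 : compatibleMemRefType, index, index
//   }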

/// Same as above, but on the slow path clone the original (out-of-bounds)
/// vector.transfer_read and store its result into `alloc` through a 0-d
/// vector view, instead of a fill + copy.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
      b, loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res =
            castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        scf::YieldOp::create(b, loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        memref::StoreOp::create(
            b, loc, vector,
            vector::TypeCastOp::create(
                b, loc, MemRefType::get({}, vector.getType()), alloc));
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        scf::YieldOp::create(b, loc, viewAndIndices);
      });
}
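
// The slow path here instead clones the out-of-bounds transfer and stashes
// its result through a 0-d vector view of %alloc (sketch, names illustrative):
//
//   } else {
//     %2 = vector.transfer_read %A[%i, %j], %pad : memref<A...>, vector<...>
//     %3 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
//     memref.store %2, %3[] : memref<vector<...>>
//     %4 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %4, %c0, %c0 : compatibleMemRefType, index, index
//   }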

/// Produce an scf.if that yields the location (view + indices) to which the
/// full vector can safely be written: the original memref on the fast path,
/// `alloc` @ [0 ... 0] on the slow path.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);
  Value memref = xferOp.getBase();
  return scf::IfOp::create(
             b, loc, inBoundsCond,
             [&](OpBuilder &b, Location loc) {
               Value res = castToCompatibleMemRefType(b, memref,
                                                      compatibleMemRefType);
               scf::ValueVector viewAndIndices{res};
               llvm::append_range(viewAndIndices, xferOp.getIndices());
               scf::YieldOp::create(b, loc, viewAndIndices);
             },
             [&](OpBuilder &b, Location loc) {
               Value casted =
                   castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
               scf::ValueVector viewAndIndices{casted};
               viewAndIndices.insert(viewAndIndices.end(),
                                     xferOp.getTransferRank(), zero);
               scf::YieldOp::create(b, loc, viewAndIndices);
             })
      ->getResults();
}
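
// The returned values are thus (sketch):
//
//   %1:3 = scf.if (%inBounds) {
//     %view = memref.cast %A : memref<A...> to compatibleMemRefType
//     scf.yield %view, %i, %j : compatibleMemRefType, index, index
//   } else {
//     %2 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %2, %c0, %c0 : compatibleMemRefType, index, index
//   }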

/// On the slow path, copy the written data from `alloc` back into the
/// original memref via a subview intersection.
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    memref::CopyOp::create(b, loc, copyArgs.first, copyArgs.second);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}
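
// The slow-path copy-back resembles (sketch):
//
//   %notInBounds = arith.xori %inBounds, %true
//   scf.if %notInBounds {
//     %src = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
//     %dst = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
//     memref.copy %src, %dst
//   }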

/// Same as above, but reload the full vector from `alloc` and replay the
/// original vector.transfer_write on the slow path.
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = arith::XOrIOp::create(
      b, loc, inBoundsCond, arith::ConstantIntOp::create(b, loc, true, 1));
  scf::IfOp::create(b, loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = memref::LoadOp::create(
        b, loc,
        vector::TypeCastOp::create(
            b, loc, MemRefType::get({}, xferOp.getVector().getType()), alloc));
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    scf::YieldOp::create(b, loc, ValueRange{});
  });
}
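
// Or, in the VectorTransfer flavor (sketch):
//
//   %notInBounds = arith.xori %inBounds, %true
//   scf.if %notInBounds {
//     %0 = vector.type_cast %alloc : memref<...> to memref<vector<...>>
//     %1 = memref.load %0[] : memref<vector<...>>
//     vector.transfer_write %1, %A[%i, %j] : vector<...>, memref<A...>
//   }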

/// Find the closest enclosing AutomaticAllocationScope that is not nested in
/// a known looping construct (an alloca inside a loop is not freed until the
/// scope exits).
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no
/// out-of-bounds masking) fast path and a slow path. If `ifOp` is not null
/// and the result is success, `ifOp` points to the newly created conditional
/// upon function return. The `options` argument selects whether the slow path
/// goes through a second vector.transfer or a linalg copy, or whether the op
/// is simply force-marked in-bounds.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
         "Expected splitFullAndPartialTransferPrecondition to hold");
  {
    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp =
        dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top-of-scope `alloc` for transient storage.
  Value alloc;
  {
    OpBuilder::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = memref::AllocaOp::create(b, scope->getLoc(),
                                     MemRefType::get(shape, elementType),
                                     ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the temporary buffer.
  // The op is cloned so the information needed for the slow-path copy-back
  // is preserved.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // On the slow path, copy back from the temporary buffer to the final
  // destination.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}
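
// End-to-end, a 2-D out-of-bounds read
//
//   %1 = vector.transfer_read %A[%i, %j], %pad : memref<A...>, vector<...>
//
// becomes (sketch):
//
//   %1:3 = scf.if (%inBounds) {
//     // Fast path: direct cast of the original view.
//     %view = memref.cast %A : memref<A...> to compatibleMemRefType
//     scf.yield %view, %i, %j : compatibleMemRefType, index, index
//   } else {
//     // Slow path: out-of-bounds vector.transfer or linalg copy into %alloc.
//     %4 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %4, %c0, %c0 : compatibleMemRefType, index, index
//   }
//   %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true, true]}
//       : compatibleMemRefType, vector<...>
//
// where %alloc is a top-of-scope memref.alloca large enough to hold one
// vector.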

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This
/// pattern may take an extra filter to perform selection at a finer
/// granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};
} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}
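
// A minimal usage sketch. The pattern is normally registered through
// populateVectorTransferFullPartialPatterns; `funcOp` and the surrounding
// pass body are assumptions for illustration:
//
//   RewritePatternSet patterns(funcOp.getContext());
//   VectorTransformsOptions options;
//   options.vectorTransferSplit = VectorTransferSplit::LinalgCopy;
//   populateVectorTransferFullPartialPatterns(patterns, options);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));
//
// (Older MLIR revisions name the driver applyPatternsAndFoldGreedily.)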