#include <type_traits>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"
/// Attempt to extract a constant index from `v`.
static std::optional<int64_t> extractConstantIndex(Value v) {
  // ...
  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
    if (affineApplyOp.getAffineMap().isSingleConstant())
      return affineApplyOp.getAffineMap().getSingleConstantResult();
  return std::nullopt;
}
/// Create the comparison `v <= ub`, folding it away entirely (returning a
/// null Value) when both operands are provably constant and the comparison
/// statically holds.
static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
  auto maybeCstV = extractConstantIndex(v);
  auto maybeCstUb = extractConstantIndex(ub);
  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
    return Value();
  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
}
/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.permutation_map().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Dimensions that are statically known to be in-bounds do not
    // participate in the condition.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size <= memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    // ... build `cond` via createFoldedSLE; a statically-true check
    // produces no condition ...
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
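// For illustration (not from the original file): for a transfer_read of
// vector<4x8xf32> from memref<?x?xf32> at (%i, %j) with both dimensions
// possibly out-of-bounds, the condition built above is morally:
//
//   %d0 = memref.dim %A, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %A, %c1 : memref<?x?xf32>
//   %e0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %e1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %ok0 = arith.cmpi sle, %e0, %d0 : index
//   %ok1 = arith.cmpi sle, %e1, %d1 : index
//   %inBounds = arith.andi %ok0, %ok1 : i1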
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();
  // Only minor identity permutation maps are supported for now.
  if (!xferOp.permutation_map().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if: this avoids
  // applying the pattern recursively.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}
/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both
/// can be cast: keep the sizes, strides, and offset on which they agree and
/// make the rest dynamic.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
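// Worked example (illustrative types, not from the original file): merging
//   aT = memref<4x7xf32, strided<[7, 1]>>
//   bT = memref<4x8xf32, strided<[8, 1]>>
// keeps dim 0 (both 4) and the inner stride (both 1), and makes the
// disagreeing size and stride dynamic:
//   memref<4x?xf32, strided<[?, 1]>>
// Both inputs can then be memref.cast to this type.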
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // The intersection along each transfer dimension is
    // min(dimMemRef - index, dimAlloc).
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
                                                xferOp.source(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.indices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}});
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  // ... build `srcIndices`/`destIndices` and unit `strides` ...
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
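// Per-dimension sketch of the IR built above (names illustrative): for a
// source view %A at index %i and staging buffer %alloc,
//
//   %szA     = memref.dim %A, %c0
//   %szAlloc = memref.dim %alloc, %c0
//   %sz = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
//             (%szA, %i, %szAlloc)
//
// and the resulting sizes parameterize the two memref.subview ops whose
// contents are copied.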
/// Given an `xferOp` for which `inBoundsCond` and a `compatibleMemRefType`
/// have been computed and a staging buffer `alloc` sized to hold a single
/// vector, produce an scf.if yielding either the original view or a padded
/// staging copy, together with updated indices.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the (possibly casted) original view and indices.
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(),
                              xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: fill `alloc` with the padding value, copy the in-bounds
        // intersection into it, and yield it with zero indices.
        // ... linalg.fill of `alloc` with the padding value ...
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
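// The produced IR resembles (a sketch; names and [...] elided operands are
// illustrative):
//
//   %1:3 = scf.if %inBounds {
//     %view = memref.cast %A : memref<A...> to compatibleMemRefType
//     scf.yield %view, %i, %j : compatibleMemRefType, index, index
//   } else {
//     linalg.fill(%pad, %alloc)
//     %3 = memref.subview %view [...] [...] [...]
//     %4 = memref.subview %alloc [0, 0] [...] [...]
//     memref.copy %3, %4
//     %5 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %5, %c0, %c0 : compatibleMemRefType, index, index
//   }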
/// Same as above, but the slow path materializes the padded vector with an
/// out-of-bounds transfer_read and stores it into `alloc`.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, returnTypes, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: yield the (possibly casted) original view and indices.
        Value res = memref;
        if (compatibleMemRefType != xferOp.getShapedType())
          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(),
                              xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: clone the transfer (which pads out-of-bounds elements)
        // and spill the resulting vector into the staging buffer.
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));
        Value casted =
            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
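// Slow-path sketch for the vector variant (illustrative): the cloned,
// still-masked transfer_read materializes the padded vector, which is then
// spilled into the staging buffer:
//
//   %2 = vector.transfer_read %view[%i, %j], %pad : memref<A...>, vector<...>
//   %3 = vector.type_cast %alloc : memref<B...> to memref<vector<...>>
//   memref.store %2, %3[] : memref<vector<...>>
//   %4 = memref.cast %alloc : memref<B...> to compatibleMemRefType
//   scf.yield %4, %c0, %c0 : compatibleMemRefType, index, index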
/// Given an `xferOp` for which `inBoundsCond` and `compatibleMemRefType`
/// have been computed, return the memref and indices a full vector write
/// should target: the original view when in-bounds, otherwise `alloc`.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, returnTypes, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res = memref;
            if (compatibleMemRefType != xferOp.getShapedType())
              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}
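// Produced IR (sketch): this scf.if only selects where the full vector
// write will land; no data is moved yet.
//
//   %1:3 = scf.if %inBounds {
//     %c = memref.cast %A : memref<A...> to compatibleMemRefType
//     scf.yield %c, %i, %j : compatibleMemRefType, index, index
//   } else {
//     %c = memref.cast %alloc : memref<B...> to compatibleMemRefType
//     scf.yield %c, %c0, %c0 : compatibleMemRefType, index, index
//   }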
/// Given a TransferWriteOp that has just written a full vector into the
/// location selected above, copy the in-bounds intersection back from
/// `alloc` to the original view when the write was out-of-bounds.
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}
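// Produced IR (sketch; [...] operands illustrative): nothing happens on the
// in-bounds path, otherwise the intersection is copied back to the view:
//
//   %notInBounds = arith.xori %inBounds, %true
//   scf.if %notInBounds {
//     %3 = memref.subview %alloc [...] [...] [...]
//     %4 = memref.subview %view [...] [...] [...]
//     memref.copy %3, %4
//   }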
/// Vector variant of the write slow path: reload the full vector from
/// `alloc` and replay the original (masked) transfer_write onto the view.
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}
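// Produced IR (sketch): reload the staged vector and replay the original,
// still-masked transfer_write onto the view:
//
//   %notInBounds = arith.xori %inBounds, %true
//   scf.if %notInBounds {
//     %0 = vector.type_cast %alloc : memref<B...> to memref<vector<...>>
//     %1 = memref.load %0[] : memref<vector<...>>
//     vector.transfer_write %1, %view[%i, %j] : vector<...>, memref<A...>
//   }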
/// Find the closest surrounding allocation scope that is not a known looping
/// construct, so the staging alloca is hoisted out of loops.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  // ...
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.updateRootInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions; keep these variables in an inner scope so they are
  // not accidentally reused below.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");
    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
    if (!(xferReadOp || xferWriteOp))
      return failure();
    // Masked transfers are not supported by the split.
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Hoist a staging alloca holding one vector to the closest enclosing
  // automatic allocation scope.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
                                  alloc.getType().cast<MemRefType>());
  if (!compatibleMemRefType)
    return failure();

  // The scf.if yields the (casted) buffer to transfer from/to plus indices.
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.transfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Rewire the read to the scf.if results; it now always reads in-bounds
    // from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
    b.updateRootInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
    });
    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide where the full vector write lands: the original view on the fast
  // path, the staging buffer otherwise.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Clone the write as an unconditional in-bounds write to that location.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // On the slow path, copy the data back from the staging buffer.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);
  return success();
}
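// End-to-end sketch (illustrative shapes): a 2-D out-of-bounds read
//
//   %1 = vector.transfer_read %A[%i, %j], %pad
//       : memref<?x8xf32>, vector<4x8xf32>
//
// is split into an scf.if computing the buffer and indices to read from,
// followed by an unconditionally in-bounds read:
//
//   %2:3 = scf.if %inBounds {
//     // fast path: cast of the original view.
//   } else {
//     // slow path: padded staging buffer.
//   }
//   %1 = vector.transfer_read %2#0[%2#1, %2#2], %pad {in_bounds = [true, true]}
//       : compatibleMemRefType, vector<4x8xf32>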
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. An
/// optional filter restricts which transfer ops get rewritten.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};
LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
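// Usage sketch (not part of this file; assumes a `funcOp` to transform and
// the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   VectorTransformsOptions options;
//   options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
//   populateVectorTransferFullPartialPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));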