30 for (
unsigned pos : permutation)
31 newInBoundsValues[pos] =
32 cast<BoolAttr>(attr.getValue()[index++]).getValue();
40 auto originalVecType = cast<VectorType>(vec.
getType());
42 newShape.append(originalVecType.getShape().begin(),
43 originalVecType.getShape().end());
44 VectorType newVecType =
46 return builder.
create<vector::BroadcastOp>(loc, newVecType, vec);
55 for (int64_t i = addedRank,
56 e = broadcasted.
getType().
cast<VectorType>().getRank();
58 permutation.push_back(i);
59 for (int64_t i = 0; i < addedRank; ++i)
60 permutation.push_back(i);
61 return builder.
create<vector::TransposeOp>(loc, broadcasted, permutation);
88 struct TransferReadPermutationLowering
95 if (op.getTransferRank() == 0)
104 op,
"map is not permutable to minor identity, apply another pattern");
118 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
121 newVectorShape[pos.value()] = originalShape[pos.index()];
122 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
126 ArrayAttr newInBoundsAttr =
128 rewriter, op.getInBounds().value(), permutation)
133 newVectorShape, op.getVectorType().getElementType(), newScalableDims);
134 Value newRead = rewriter.
create<vector::TransferReadOp>(
135 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
163 struct TransferWritePermutationLowering
170 if (op.getTransferRank() == 0)
180 op,
"map is not permutable to minor identity, apply another pattern");
190 llvm::transform(permutationMap.
getResults(), std::back_inserter(indices),
192 return dyn_cast<AffineDimExpr>(expr).getPosition();
196 ArrayAttr newInBoundsAttr =
198 rewriter, op.getInBounds().value(), permutation)
202 Value newVec = rewriter.
create<vector::TransposeOp>(
203 op.
getLoc(), op.getVector(), indices);
208 op.getMask(), newInBoundsAttr);
229 struct TransferWriteNonPermutationLowering
236 if (op.getTransferRank() == 0)
244 "map is already permutable to minor identity, apply another pattern");
251 foundDim[cast<AffineDimExpr>(exp).getPosition()] =
true;
253 bool foundFirstDim =
false;
255 for (
size_t i = 0; i < foundDim.size(); i++) {
257 foundFirstDim =
true;
264 missingInnerDim.push_back(i);
269 missingInnerDim.size());
274 missingInnerDim.size());
280 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
281 newInBoundsValues.push_back(op.isDimInBounds(i));
286 newMask, newInBoundsAttr);
299 struct TransferOpReduceRank :
public OpRewritePattern<vector::TransferReadOp> {
305 if (op.getTransferRank() == 0)
309 unsigned numLeadingBroadcast = 0;
311 auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
312 if (!dimExpr || dimExpr.getValue() != 0)
314 numLeadingBroadcast++;
317 if (numLeadingBroadcast == 0)
320 VectorType originalVecType = op.getVectorType();
321 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
330 op,
"map is not a minor identity with broadcasting");
336 if (reducedShapeRank == 0) {
338 if (isa<TensorType>(op.getShapedType())) {
339 newRead = rewriter.
create<tensor::ExtractOp>(
340 op.
getLoc(), op.getSource(), op.getIndices());
342 newRead = rewriter.
create<memref::LoadOp>(
343 op.
getLoc(), originalVecType.getElementType(), op.getSource(),
352 originalVecType.getShape().take_back(reducedShapeRank));
354 originalVecType.getScalableDims().take_back(reducedShapeRank));
356 if (newShape.empty())
360 newShape, originalVecType.getElementType(), newScalableDims);
361 ArrayAttr newInBoundsAttr =
364 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
366 Value newRead = rewriter.
create<vector::TransferReadOp>(
367 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
381 .
add<TransferReadPermutationLowering, TransferWritePermutationLowering,
382 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
399 struct TransferReadToVectorLoadLowering
401 TransferReadToVectorLoadLowering(
MLIRContext *context,
402 std::optional<unsigned> maxRank,
405 maxTransferRank(maxRank) {}
409 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
411 read,
"vector type is greater than max transfer rank");
418 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
422 auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
435 for (
unsigned i : broadcastedDims)
436 unbroadcastedVectorShape[i] = 1;
437 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
438 unbroadcastedVectorShape, read.getVectorType().getElementType());
442 auto memrefElTy = memRefType.getElementType();
443 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
447 if (!isa<VectorType>(memrefElTy) &&
448 memrefElTy != read.getVectorType().getElementType())
452 if (read.hasOutOfBoundsDim())
457 if (read.getMask()) {
458 if (read.getVectorType().getRank() != 1)
461 read,
"vector type is not rank 1, can't create masked load, needs "
465 read.getLoc(), unbroadcastedVectorType, read.getPadding());
466 loadOp = rewriter.
create<vector::MaskedLoadOp>(
467 read.getLoc(), unbroadcastedVectorType, read.getSource(),
468 read.getIndices(), read.getMask(), fill);
470 loadOp = rewriter.
create<vector::LoadOp>(
471 read.getLoc(), unbroadcastedVectorType, read.getSource(),
476 if (!broadcastedDims.empty()) {
478 read, read.getVectorType(), loadOp->getResult(0));
480 rewriter.
replaceOp(read, loadOp->getResult(0));
486 std::optional<unsigned> maxTransferRank;
497 struct VectorLoadToMemrefLoadLowering
503 auto vecType = loadOp.getVectorType();
504 if (vecType.getNumElements() != 1)
507 auto memrefLoad = rewriter.
create<memref::LoadOp>(
508 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
516 struct VectorStoreToMemrefStoreLowering
522 auto vecType = storeOp.getVectorType();
523 if (vecType.getNumElements() != 1)
527 if (vecType.getRank() == 0) {
529 extracted = rewriter.
create<vector::ExtractElementOp>(
530 storeOp.getLoc(), storeOp.getValueToStore());
533 extracted = rewriter.
create<vector::ExtractOp>(
534 storeOp.getLoc(), storeOp.getValueToStore(), indices);
538 storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
551 struct TransferWriteToVectorStoreLowering
553 TransferWriteToVectorStoreLowering(
MLIRContext *context,
554 std::optional<unsigned> maxRank,
557 maxTransferRank(maxRank) {}
561 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
563 write,
"vector type is greater than max transfer rank");
569 !write.getPermutationMap().isMinorIdentity())
571 diag <<
"permutation map is not minor identity: " << write;
574 auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
577 diag <<
"not a memref type: " << write;
583 diag <<
"most minor stride is not 1: " << write;
588 auto memrefElTy = memRefType.getElementType();
589 if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
591 diag <<
"elemental type mismatch: " << write;
595 if (!isa<VectorType>(memrefElTy) &&
596 memrefElTy != write.getVectorType().getElementType())
598 diag <<
"elemental type mismatch: " << write;
602 if (write.hasOutOfBoundsDim())
604 diag <<
"out of bounds dim: " << write;
606 if (write.getMask()) {
607 if (write.getVectorType().getRank() != 1)
611 diag <<
"vector type is not rank 1, can't create masked store, "
612 "needs VectorToSCF: "
617 write, write.getSource(), write.getIndices(), write.getMask(),
621 write, write.getVector(), write.getSource(), write.getIndices());
626 std::optional<unsigned> maxTransferRank;
633 patterns.
add<TransferReadToVectorLoadLowering,
634 TransferWriteToVectorStoreLowering>(patterns.
getContext(),
635 maxTransferRank, benefit);
637 .
add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static ArrayRef< int64_t > vectorShape(Type type)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...