30 for (
unsigned pos : permutation)
31 newInBoundsValues[pos] =
32 cast<BoolAttr>(attr.getValue()[index++]).getValue();
40 auto originalVecType = cast<VectorType>(vec.
getType());
42 newShape.append(originalVecType.getShape().begin(),
43 originalVecType.getShape().end());
46 newScalableDims.append(originalVecType.getScalableDims().begin(),
47 originalVecType.getScalableDims().end());
49 newShape, originalVecType.getElementType(), newScalableDims);
50 return builder.
create<vector::BroadcastOp>(loc, newVecType, vec);
59 for (int64_t i = addedRank,
60 e = cast<VectorType>(broadcasted.
getType()).getRank();
62 permutation.push_back(i);
63 for (int64_t i = 0; i < addedRank; ++i)
64 permutation.push_back(i);
65 return builder.
create<vector::TransposeOp>(loc, broadcasted, permutation);
92 struct TransferReadPermutationLowering
94 using MaskableOpRewritePattern::MaskableOpRewritePattern;
96 FailureOr<mlir::Value>
97 matchAndRewriteMaskableOp(vector::TransferReadOp op,
98 MaskingOpInterface maskOp,
101 if (op.getTransferRank() == 0)
113 op,
"map is not permutable to minor identity, apply another pattern");
127 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
130 newVectorShape[pos.value()] = originalShape[pos.index()];
131 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
135 ArrayAttr newInBoundsAttr =
140 newVectorShape, op.getVectorType().getElementType(), newScalableDims);
141 Value newRead = rewriter.
create<vector::TransferReadOp>(
142 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
149 .
create<vector::TransposeOp>(op.
getLoc(), newRead, transposePerm)
170 struct TransferWritePermutationLowering
172 using MaskableOpRewritePattern::MaskableOpRewritePattern;
174 FailureOr<mlir::Value>
175 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
176 MaskingOpInterface maskOp,
179 if (op.getTransferRank() == 0)
192 op,
"map is not permutable to minor identity, apply another pattern");
202 llvm::transform(permutationMap.
getResults(), std::back_inserter(indices),
204 return dyn_cast<AffineDimExpr>(expr).getPosition();
208 ArrayAttr newInBoundsAttr =
212 Value newVec = rewriter.
create<vector::TransposeOp>(
213 op.
getLoc(), op.getVector(), indices);
216 auto newWrite = rewriter.
create<vector::TransferWriteOp>(
217 op.
getLoc(), newVec, op.getSource(), op.getIndices(),
219 if (newWrite.hasPureTensorSemantics())
242 struct TransferWriteNonPermutationLowering
244 using MaskableOpRewritePattern::MaskableOpRewritePattern;
246 FailureOr<mlir::Value>
247 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
248 MaskingOpInterface maskOp,
251 if (op.getTransferRank() == 0)
262 "map is already permutable to minor identity, apply another pattern");
269 foundDim[cast<AffineDimExpr>(exp).getPosition()] =
true;
271 bool foundFirstDim =
false;
273 for (
size_t i = 0; i < foundDim.size(); i++) {
275 foundFirstDim =
true;
282 missingInnerDim.push_back(i);
287 missingInnerDim.size());
292 missingInnerDim.size());
298 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
299 newInBoundsValues.push_back(op.isDimInBounds(i));
302 auto newWrite = rewriter.
create<vector::TransferWriteOp>(
303 op.
getLoc(), newVec, op.getSource(), op.getIndices(),
305 if (newWrite.hasPureTensorSemantics())
321 struct TransferOpReduceRank
323 using MaskableOpRewritePattern::MaskableOpRewritePattern;
325 FailureOr<mlir::Value>
326 matchAndRewriteMaskableOp(vector::TransferReadOp op,
327 MaskingOpInterface maskOp,
330 if (op.getTransferRank() == 0)
337 unsigned numLeadingBroadcast = 0;
339 auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
340 if (!dimExpr || dimExpr.getValue() != 0)
342 numLeadingBroadcast++;
345 if (numLeadingBroadcast == 0)
348 VectorType originalVecType = op.getVectorType();
349 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
358 op,
"map is not a minor identity with broadcasting");
364 if (reducedShapeRank == 0) {
366 if (isa<TensorType>(op.getShapedType())) {
367 newRead = rewriter.
create<tensor::ExtractOp>(
368 op.
getLoc(), op.getSource(), op.getIndices());
370 newRead = rewriter.
create<memref::LoadOp>(
371 op.
getLoc(), originalVecType.getElementType(), op.getSource(),
375 .
create<vector::BroadcastOp>(op.
getLoc(), originalVecType, newRead)
380 originalVecType.getShape().take_back(reducedShapeRank));
382 originalVecType.getScalableDims().take_back(reducedShapeRank));
384 if (newShape.empty())
388 newShape, originalVecType.getElementType(), newScalableDims);
389 ArrayAttr newInBoundsAttr =
392 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
394 Value newRead = rewriter.
create<vector::TransferReadOp>(
395 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
399 .
create<vector::BroadcastOp>(op.
getLoc(), originalVecType, newRead)
409 .
add<TransferReadPermutationLowering, TransferWritePermutationLowering,
410 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
427 struct TransferReadToVectorLoadLowering
429 TransferReadToVectorLoadLowering(
MLIRContext *context,
430 std::optional<unsigned> maxRank,
433 maxTransferRank(maxRank) {}
435 FailureOr<mlir::Value>
436 matchAndRewriteMaskableOp(vector::TransferReadOp read,
437 MaskingOpInterface maskOp,
439 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
441 read,
"vector type is greater than max transfer rank");
450 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
454 auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
466 for (
unsigned i : broadcastedDims)
467 unbroadcastedVectorShape[i] = 1;
468 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
469 unbroadcastedVectorShape, read.getVectorType().getElementType());
473 auto memrefElTy = memRefType.getElementType();
474 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
478 if (!isa<VectorType>(memrefElTy) &&
479 memrefElTy != read.getVectorType().getElementType())
483 if (read.hasOutOfBoundsDim())
488 if (read.getMask()) {
489 if (read.getVectorType().getRank() != 1)
492 read,
"vector type is not rank 1, can't create masked load, needs "
496 read.getLoc(), unbroadcastedVectorType, read.getPadding());
497 res = rewriter.
create<vector::MaskedLoadOp>(
498 read.getLoc(), unbroadcastedVectorType, read.getSource(),
499 read.getIndices(), read.getMask(), fill);
501 res = rewriter.
create<vector::LoadOp>(
502 read.getLoc(), unbroadcastedVectorType, read.getSource(),
507 if (!broadcastedDims.empty())
508 res = rewriter.
create<vector::BroadcastOp>(
509 read.getLoc(), read.getVectorType(), res->getResult(0));
513 std::optional<unsigned> maxTransferRank;
524 struct VectorLoadToMemrefLoadLowering
528 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
530 auto vecType = loadOp.getVectorType();
531 if (vecType.getNumElements() != 1)
534 auto memrefLoad = rewriter.
create<memref::LoadOp>(
535 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
543 struct VectorStoreToMemrefStoreLowering
547 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
549 auto vecType = storeOp.getVectorType();
550 if (vecType.getNumElements() != 1)
554 if (vecType.getRank() == 0) {
556 extracted = rewriter.
create<vector::ExtractElementOp>(
557 storeOp.getLoc(), storeOp.getValueToStore());
560 extracted = rewriter.
create<vector::ExtractOp>(
561 storeOp.getLoc(), storeOp.getValueToStore(), indices);
565 storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
578 struct TransferWriteToVectorStoreLowering
580 TransferWriteToVectorStoreLowering(
MLIRContext *context,
581 std::optional<unsigned> maxRank,
584 maxTransferRank(maxRank) {}
586 FailureOr<mlir::Value>
587 matchAndRewriteMaskableOp(vector::TransferWriteOp write,
588 MaskingOpInterface maskOp,
590 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
592 write,
"vector type is greater than max transfer rank");
600 !write.getPermutationMap().isMinorIdentity())
602 diag <<
"permutation map is not minor identity: " << write;
605 auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
608 diag <<
"not a memref type: " << write;
614 diag <<
"most minor stride is not 1: " << write;
619 auto memrefElTy = memRefType.getElementType();
620 if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
622 diag <<
"elemental type mismatch: " << write;
626 if (!isa<VectorType>(memrefElTy) &&
627 memrefElTy != write.getVectorType().getElementType())
629 diag <<
"elemental type mismatch: " << write;
633 if (write.hasOutOfBoundsDim())
635 diag <<
"out of bounds dim: " << write;
637 if (write.getMask()) {
638 if (write.getVectorType().getRank() != 1)
642 diag <<
"vector type is not rank 1, can't create masked store, "
643 "needs VectorToSCF: "
647 rewriter.
create<vector::MaskedStoreOp>(
648 write.getLoc(), write.getSource(), write.getIndices(),
649 write.getMask(), write.getVector());
651 rewriter.
create<vector::StoreOp>(write.getLoc(), write.getVector(),
652 write.getSource(), write.getIndices());
659 std::optional<unsigned> maxTransferRank;
666 patterns.
add<TransferReadToVectorLoadLowering,
667 TransferWriteToVectorStoreLowering>(patterns.
getContext(),
668 maxTransferRank, benefit);
670 .
add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static VectorShape vectorShape(Type type)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.