30 for (
unsigned pos : permutation)
31 newInBoundsValues[pos] =
32 cast<BoolAttr>(attr.getValue()[index++]).getValue();
40 auto originalVecType = cast<VectorType>(vec.
getType());
42 newShape.append(originalVecType.getShape().begin(),
43 originalVecType.getShape().end());
46 newScalableDims.append(originalVecType.getScalableDims().begin(),
47 originalVecType.getScalableDims().end());
49 newShape, originalVecType.getElementType(), newScalableDims);
50 return builder.
create<vector::BroadcastOp>(loc, newVecType, vec);
59 for (int64_t i = addedRank,
60 e = cast<VectorType>(broadcasted.
getType()).getRank();
62 permutation.push_back(i);
63 for (int64_t i = 0; i < addedRank; ++i)
64 permutation.push_back(i);
65 return builder.
create<vector::TransposeOp>(loc, broadcasted, permutation);
92 struct TransferReadPermutationLowering
99 if (op.getTransferRank() == 0)
108 op,
"map is not permutable to minor identity, apply another pattern");
122 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
125 newVectorShape[pos.value()] = originalShape[pos.index()];
126 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
130 ArrayAttr newInBoundsAttr =
132 rewriter, op.getInBounds().value(), permutation)
137 newVectorShape, op.getVectorType().getElementType(), newScalableDims);
138 Value newRead = rewriter.
create<vector::TransferReadOp>(
139 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
167 struct TransferWritePermutationLowering
174 if (op.getTransferRank() == 0)
184 op,
"map is not permutable to minor identity, apply another pattern");
194 llvm::transform(permutationMap.
getResults(), std::back_inserter(indices),
196 return dyn_cast<AffineDimExpr>(expr).getPosition();
200 ArrayAttr newInBoundsAttr =
202 rewriter, op.getInBounds().value(), permutation)
206 Value newVec = rewriter.
create<vector::TransposeOp>(
207 op.
getLoc(), op.getVector(), indices);
212 op.getMask(), newInBoundsAttr);
233 struct TransferWriteNonPermutationLowering
240 if (op.getTransferRank() == 0)
248 "map is already permutable to minor identity, apply another pattern");
255 foundDim[cast<AffineDimExpr>(exp).getPosition()] =
true;
257 bool foundFirstDim =
false;
259 for (
size_t i = 0; i < foundDim.size(); i++) {
261 foundFirstDim =
true;
268 missingInnerDim.push_back(i);
273 missingInnerDim.size());
278 missingInnerDim.size());
284 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
285 newInBoundsValues.push_back(op.isDimInBounds(i));
290 newMask, newInBoundsAttr);
303 struct TransferOpReduceRank :
public OpRewritePattern<vector::TransferReadOp> {
309 if (op.getTransferRank() == 0)
313 unsigned numLeadingBroadcast = 0;
315 auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
316 if (!dimExpr || dimExpr.getValue() != 0)
318 numLeadingBroadcast++;
321 if (numLeadingBroadcast == 0)
324 VectorType originalVecType = op.getVectorType();
325 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
334 op,
"map is not a minor identity with broadcasting");
340 if (reducedShapeRank == 0) {
342 if (isa<TensorType>(op.getShapedType())) {
343 newRead = rewriter.
create<tensor::ExtractOp>(
344 op.
getLoc(), op.getSource(), op.getIndices());
346 newRead = rewriter.
create<memref::LoadOp>(
347 op.
getLoc(), originalVecType.getElementType(), op.getSource(),
356 originalVecType.getShape().take_back(reducedShapeRank));
358 originalVecType.getScalableDims().take_back(reducedShapeRank));
360 if (newShape.empty())
364 newShape, originalVecType.getElementType(), newScalableDims);
365 ArrayAttr newInBoundsAttr =
368 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
370 Value newRead = rewriter.
create<vector::TransferReadOp>(
371 op.
getLoc(), newReadType, op.getSource(), op.getIndices(),
385 .
add<TransferReadPermutationLowering, TransferWritePermutationLowering,
386 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
403 struct TransferReadToVectorLoadLowering
405 TransferReadToVectorLoadLowering(
MLIRContext *context,
406 std::optional<unsigned> maxRank,
409 maxTransferRank(maxRank) {}
413 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
415 read,
"vector type is greater than max transfer rank");
422 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
426 auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
439 for (
unsigned i : broadcastedDims)
440 unbroadcastedVectorShape[i] = 1;
441 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
442 unbroadcastedVectorShape, read.getVectorType().getElementType());
446 auto memrefElTy = memRefType.getElementType();
447 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
451 if (!isa<VectorType>(memrefElTy) &&
452 memrefElTy != read.getVectorType().getElementType())
456 if (read.hasOutOfBoundsDim())
461 if (read.getMask()) {
462 if (read.getVectorType().getRank() != 1)
465 read,
"vector type is not rank 1, can't create masked load, needs "
469 read.getLoc(), unbroadcastedVectorType, read.getPadding());
470 loadOp = rewriter.
create<vector::MaskedLoadOp>(
471 read.getLoc(), unbroadcastedVectorType, read.getSource(),
472 read.getIndices(), read.getMask(), fill);
474 loadOp = rewriter.
create<vector::LoadOp>(
475 read.getLoc(), unbroadcastedVectorType, read.getSource(),
480 if (!broadcastedDims.empty()) {
482 read, read.getVectorType(), loadOp->getResult(0));
484 rewriter.
replaceOp(read, loadOp->getResult(0));
490 std::optional<unsigned> maxTransferRank;
501 struct VectorLoadToMemrefLoadLowering
507 auto vecType = loadOp.getVectorType();
508 if (vecType.getNumElements() != 1)
511 auto memrefLoad = rewriter.
create<memref::LoadOp>(
512 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
520 struct VectorStoreToMemrefStoreLowering
526 auto vecType = storeOp.getVectorType();
527 if (vecType.getNumElements() != 1)
531 if (vecType.getRank() == 0) {
533 extracted = rewriter.
create<vector::ExtractElementOp>(
534 storeOp.getLoc(), storeOp.getValueToStore());
537 extracted = rewriter.
create<vector::ExtractOp>(
538 storeOp.getLoc(), storeOp.getValueToStore(), indices);
542 storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
555 struct TransferWriteToVectorStoreLowering
557 TransferWriteToVectorStoreLowering(
MLIRContext *context,
558 std::optional<unsigned> maxRank,
561 maxTransferRank(maxRank) {}
565 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
567 write,
"vector type is greater than max transfer rank");
573 !write.getPermutationMap().isMinorIdentity())
575 diag <<
"permutation map is not minor identity: " << write;
578 auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
581 diag <<
"not a memref type: " << write;
587 diag <<
"most minor stride is not 1: " << write;
592 auto memrefElTy = memRefType.getElementType();
593 if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
595 diag <<
"elemental type mismatch: " << write;
599 if (!isa<VectorType>(memrefElTy) &&
600 memrefElTy != write.getVectorType().getElementType())
602 diag <<
"elemental type mismatch: " << write;
606 if (write.hasOutOfBoundsDim())
608 diag <<
"out of bounds dim: " << write;
610 if (write.getMask()) {
611 if (write.getVectorType().getRank() != 1)
615 diag <<
"vector type is not rank 1, can't create masked store, "
616 "needs VectorToSCF: "
621 write, write.getSource(), write.getIndices(), write.getMask(),
625 write, write.getVector(), write.getSource(), write.getIndices());
630 std::optional<unsigned> maxTransferRank;
637 patterns.
add<TransferReadToVectorLoadLowering,
638 TransferWriteToVectorStoreLowering>(patterns.
getContext(),
639 maxTransferRank, benefit);
641 .
add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRank by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRank by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static VectorShape vectorShape(Type type)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...