27 for (
unsigned pos : permutation)
28 newInBoundsValues[pos] =
29 cast<BoolAttr>(attr.getValue()[
index++]).getValue();
37 auto originalVecType = cast<VectorType>(vec.
getType());
39 newShape.append(originalVecType.getShape().begin(),
40 originalVecType.getShape().end());
43 newScalableDims.append(originalVecType.getScalableDims().begin(),
44 originalVecType.getScalableDims().end());
45 VectorType newVecType = VectorType::get(
46 newShape, originalVecType.getElementType(), newScalableDims);
47 return vector::BroadcastOp::create(builder, loc, newVecType, vec);
57 e = cast<VectorType>(broadcasted.
getType()).getRank();
59 permutation.push_back(i);
60 for (
int64_t i = 0; i < addedRank; ++i)
61 permutation.push_back(i);
62 return vector::TransposeOp::create(builder, loc, broadcasted, permutation);
89struct TransferReadPermutationLowering
91 using MaskableOpRewritePattern::MaskableOpRewritePattern;
93 FailureOr<mlir::Value>
94 matchAndRewriteMaskableOp(vector::TransferReadOp op,
95 MaskingOpInterface maskOp,
96 PatternRewriter &rewriter)
const override {
98 if (op.getTransferRank() == 0)
104 SmallVector<unsigned> permutation;
105 AffineMap map = op.getPermutationMap();
110 op,
"map is not permutable to minor identity, apply another pattern");
112 AffineMap permutationMap =
120 AffineMap newMap = permutationMap.
compose(map);
122 ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
123 SmallVector<int64_t> newVectorShape(originalShape.size());
124 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
125 SmallVector<bool> newScalableDims(originalShape.size());
126 for (
const auto &pos : llvm::enumerate(permutation)) {
127 newVectorShape[pos.value()] = originalShape[pos.index()];
128 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
136 VectorType newReadType = VectorType::get(
137 newVectorShape, op.getVectorType().getElementType(), newScalableDims);
138 Value newRead = vector::TransferReadOp::create(
139 rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
140 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
144 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145 return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
167struct TransferWritePermutationLowering
169 using MaskableOpRewritePattern::MaskableOpRewritePattern;
171 FailureOr<mlir::Value>
172 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
173 MaskingOpInterface maskOp,
174 PatternRewriter &rewriter)
const override {
176 if (op.getTransferRank() == 0)
182 SmallVector<unsigned> permutation;
189 op,
"map is not permutable to minor identity, apply another pattern");
200 [](AffineExpr expr) {
201 return dyn_cast<AffineDimExpr>(expr).getPosition();
209 Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(),
213 auto newWrite = vector::TransferWriteOp::create(
214 rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
215 AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
216 if (newWrite.hasPureTensorSemantics())
217 return newWrite.getResult();
239struct TransferWriteNonPermutationLowering
241 using MaskableOpRewritePattern::MaskableOpRewritePattern;
243 FailureOr<mlir::Value>
244 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
245 MaskingOpInterface maskOp,
246 PatternRewriter &rewriter)
const override {
248 if (op.getTransferRank() == 0)
254 SmallVector<unsigned> permutation;
259 "map is already permutable to minor identity, apply another pattern");
264 SmallVector<bool> foundDim(map.
getNumDims(),
false);
266 foundDim[cast<AffineDimExpr>(exp).getPosition()] =
true;
267 SmallVector<AffineExpr> exprs;
268 bool foundFirstDim =
false;
269 SmallVector<int64_t> missingInnerDim;
270 for (
size_t i = 0; i < foundDim.size(); i++) {
272 foundFirstDim =
true;
279 missingInnerDim.push_back(i);
284 missingInnerDim.size());
289 missingInnerDim.size());
294 SmallVector<bool> newInBoundsValues(missingInnerDim.size(),
true);
295 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
296 newInBoundsValues.push_back(op.isDimInBounds(i));
299 auto newWrite = vector::TransferWriteOp::create(
300 rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
301 AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
302 if (newWrite.hasPureTensorSemantics())
303 return newWrite.getResult();
318struct TransferOpReduceRank
320 using MaskableOpRewritePattern::MaskableOpRewritePattern;
322 FailureOr<mlir::Value>
323 matchAndRewriteMaskableOp(vector::TransferReadOp op,
324 MaskingOpInterface maskOp,
325 PatternRewriter &rewriter)
const override {
327 if (op.getTransferRank() == 0)
334 unsigned numLeadingBroadcast = 0;
336 auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
337 if (!dimExpr || dimExpr.getValue() != 0)
339 numLeadingBroadcast++;
342 if (numLeadingBroadcast == 0)
345 VectorType originalVecType = op.getVectorType();
346 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
355 op,
"map is not a minor identity with broadcasting");
358 SmallVector<int64_t> newShape(
359 originalVecType.getShape().take_back(reducedShapeRank));
360 SmallVector<bool> newScalableDims(
361 originalVecType.getScalableDims().take_back(reducedShapeRank));
363 VectorType newReadType = VectorType::get(
364 newShape, originalVecType.getElementType(), newScalableDims);
368 op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
370 Value newRead = vector::TransferReadOp::create(
371 rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
372 AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
374 return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
385 .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
386 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
403struct TransferReadToVectorLoadLowering
405 TransferReadToVectorLoadLowering(
MLIRContext *context,
406 std::optional<unsigned> maxRank,
409 maxTransferRank(maxRank) {}
411 FailureOr<mlir::Value>
412 matchAndRewriteMaskableOp(vector::TransferReadOp read,
413 MaskingOpInterface maskOp,
415 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
417 read,
"vector type is greater than max transfer rank");
422 SmallVector<unsigned> broadcastedDims;
426 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
430 auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
435 if (!memRefType.isLastDimUnitStride())
440 ArrayRef<int64_t>
vectorShape = read.getVectorType().getShape();
441 SmallVector<int64_t> unbroadcastedVectorShape(
vectorShape);
442 for (
unsigned i : broadcastedDims)
443 unbroadcastedVectorShape[i] = 1;
444 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
445 unbroadcastedVectorShape, read.getVectorType().getElementType());
449 auto memrefElTy = memRefType.getElementType();
450 if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
454 if (!isa<VectorType>(memrefElTy) &&
455 memrefElTy != read.getVectorType().getElementType())
459 if (read.hasOutOfBoundsDim())
464 if (read.getMask()) {
465 if (read.getVectorType().getRank() != 1)
468 read,
"vector type is not rank 1, can't create masked load, needs "
471 Value fill = vector::BroadcastOp::create(
472 rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
473 res = vector::MaskedLoadOp::create(
474 rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
475 read.getIndices(), read.getMask(), fill);
477 res = vector::LoadOp::create(rewriter, read.getLoc(),
478 unbroadcastedVectorType, read.getBase(),
483 if (!broadcastedDims.empty())
484 res = vector::BroadcastOp::create(
485 rewriter, read.getLoc(), read.getVectorType(), res->
getResult(0));
489 std::optional<unsigned> maxTransferRank;
500struct TransferWriteToVectorStoreLowering
502 TransferWriteToVectorStoreLowering(MLIRContext *context,
503 std::optional<unsigned> maxRank,
504 PatternBenefit benefit = 1)
505 : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
506 maxTransferRank(maxRank) {}
508 FailureOr<mlir::Value>
509 matchAndRewriteMaskableOp(vector::TransferWriteOp write,
510 MaskingOpInterface maskOp,
511 PatternRewriter &rewriter)
const override {
512 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
514 write,
"vector type is greater than max transfer rank");
522 !write.getPermutationMap().isMinorIdentity())
524 diag <<
"permutation map is not minor identity: " << write;
527 auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
530 diag <<
"not a memref type: " << write;
534 if (!memRefType.isLastDimUnitStride())
536 diag <<
"most minor stride is not 1: " << write;
541 auto memrefElTy = memRefType.getElementType();
542 if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
544 diag <<
"elemental type mismatch: " << write;
548 if (!isa<VectorType>(memrefElTy) &&
549 memrefElTy != write.getVectorType().getElementType())
551 diag <<
"elemental type mismatch: " << write;
555 if (write.hasOutOfBoundsDim())
557 diag <<
"out of bounds dim: " << write;
559 if (write.getMask()) {
560 if (write.getVectorType().getRank() != 1)
563 write.getLoc(), [=](Diagnostic &
diag) {
564 diag <<
"vector type is not rank 1, can't create masked store, "
565 "needs VectorToSCF: "
569 vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(),
570 write.getIndices(), write.getMask(),
573 vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(),
574 write.getBase(), write.getIndices());
581 std::optional<unsigned> maxTransferRank;
588 patterns.add<TransferReadToVectorLoadLowering,
589 TransferWriteToVectorStoreLowering>(
patterns.getContext(),
590 maxTransferRank, benefit);
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, const SmallVector< unsigned > &permutation)
Transpose a vector transfer op's in_bounds attribute by applying reverse permutation based on the giv...
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding inner unit dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank)
Extend the rank of a vector Value by addedRanks by adding outer unit dimensions.
static std::string diag(const llvm::Value &value)
static std::optional< VectorShape > vectorShape(Type type)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
const FrozenRewritePatternSet & patterns
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.