/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// reverse permutation.
static ArrayAttr inverseTransposeInBoundsAttr(OpBuilder &builder,
                                              ArrayAttr attr,
                                              const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());
  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return vector::BroadcastOp::create(builder, loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return vector::TransposeOp::create(builder, loc, broadcasted, permutation);
}
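
// For illustration, a hedged sketch of what the two helpers above produce at
// the IR level. The shapes are hypothetical and not taken from this file.
//
//   extendVectorRank(%v : vector<4x3xf32>, addedRank = 2):
//     %b = vector.broadcast %v : vector<4x3xf32> to vector<1x1x4x3xf32>
//
//   extendMaskRank(%m : vector<4x3xi1>, addedRank = 2):
//     %b = vector.broadcast %m : vector<4x3xi1> to vector<1x1x4x3xi1>
//     %t = vector.transpose %b, [2, 3, 0, 1]
//            : vector<1x1x4x3xi1> to vector<4x3x1x1xi1>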

/// Lower a transfer_read whose permutation map is a permutation of a minor
/// identity (with broadcasts) into a minor-identity transfer_read followed by
/// a vector.transpose.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case");

    SmallVector<unsigned> permutation;
    if (!op.getPermutationMap().isPermutationOfMinorIdentityWithBroadcasting(
            permutation))
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");

    // Un-permute the original vector shape to deduce the type of the new read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    SmallVector<bool> newScalableDims(originalShape.size());
    for (auto pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose the in_bounds attribute with the same permutation.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBoundsAttr(), permutation);

    // Create the new (minor-identity) transfer_read, then transpose its result
    // back into the originally requested layout.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = vector::TransferReadOp::create(
        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        /*...new permutation map, padding, mask, in_bounds elided...*/);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return vector::TransposeOp::create(rewriter, op.getLoc(), newRead,
                                       transposePerm);
  }
};
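
// Hedged illustration (hypothetical operands and shapes, not from this file;
// in_bounds omitted for brevity) of the rewrite performed by
// TransferReadPermutationLowering:
//
//   %v = vector.transfer_read %m[%a, %b], %pad
//       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//       : memref<?x?xf32>, vector<4x8xf32>
// becomes:
//   %r = vector.transfer_read %m[%a, %b], %pad
//       {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//       : memref<?x?xf32>, vector<8x4xf32>
//   %v = vector.transpose %r, [1, 0] : vector<8x4xf32> to vector<4x8xf32>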

/// Lower a transfer_write whose permutation map is a permutation of a minor
/// identity into a vector.transpose followed by a minor-identity
/// transfer_write.
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case");

    SmallVector<unsigned> permutation;
    if (!op.getPermutationMap().isPermutationOfMinorIdentityWithBroadcasting(
            permutation))
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");

    // `permutationMap` is the inverse of the op's (compressed) permutation
    // map; its results give the transpose permutation to apply to the vector.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose the in_bounds attribute as well.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBoundsAttr(), permutation);

    // Transpose the vector, then write it out with a minor-identity map.
    Value newVec = vector::TransposeOp::create(rewriter, op.getLoc(),
                                               op.getVector(), indices);
    auto newWrite = vector::TransferWriteOp::create(
        rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
        /*...minor-identity map, mask, newInBoundsAttr elided...*/);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // The memref form of transfer_write has no result.
    return Value();
  }
};
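
// Hedged illustration (hypothetical operands and shapes, not from this file)
// of the rewrite performed by TransferWritePermutationLowering:
//
//   vector.transfer_write %v, %m[%a, %b, %c]
//       {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>}
//       : vector<2x4x8xf32>, memref<?x?x?xf32>
// becomes:
//   %t = vector.transpose %v, [1, 2, 0]
//          : vector<2x4x8xf32> to vector<4x8x2xf32>
//   vector.transfer_write %t, %m[%a, %b, %c]
//       {permutation_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>}
//       : vector<4x8x2xf32>, memref<?x?x?xf32>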

/// Rewrite a transfer_write whose map is not a permutation of a minor identity
/// by extending the vector (and mask) with unit dimensions for the missing
/// inner memory dimensions, so that the patterns above can apply afterwards.
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case");

    AffineMap map = op.getPermutationMap();
    SmallVector<unsigned> permutationMap;
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutationMap))
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");

    // Record which memory dimensions the map uses, then collect the missing
    // inner dimensions (those appearing after the first used dimension).
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<int64_t> missingInnerDim;
    bool foundFirstDim = false;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (foundFirstDim)
        missingInnerDim.push_back(i);
    }

    // Extend the vector with outer unit dims and the mask with inner unit dims
    // to account for the missing dimensions.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());

    // Added unit dims are in-bounds; keep the original values for the rest.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i)
      newInBoundsValues.push_back(op.isDimInBounds(i));

    auto newWrite = vector::TransferWriteOp::create(
        rewriter, op.getLoc(), newVec, op.getBase(), op.getIndices(),
        /*...extended map, newMask, new in_bounds attribute elided...*/);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    return Value();
  }
};
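
// Hedged illustration (hypothetical types, not from this file) of the rewrite
// performed by TransferWriteNonPermutationLowering: the map skips the inner
// memory dimension d3, so a unit dimension is added to the vector and the map.
//
//   vector.transfer_write %v, %m[%i0, %i1, %i2, %i3]
//       {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>}
//       : vector<8x16xf32>, memref<?x?x?x?xf32>
// becomes:
//   %w = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
//   vector.transfer_write %w, %m[%i0, %i1, %i2, %i3]
//       {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>}
//       : vector<1x8x16xf32>, memref<?x?x?x?xf32>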

/// Lower a transfer_read whose map has leading broadcast (constant 0) results
/// into a lower-rank transfer_read followed by a vector.broadcast.
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case");

    // Count the leading broadcast (constant 0) results of the permutation map.
    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (AffineExpr expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // The map with the leading zeros dropped must be a minor identity with
    // broadcasting; otherwise a permutation has to be applied first.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    if (!newMap.isMinorIdentityWithBroadcasting())
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");

    // Build the reduced-rank read and broadcast it back to the original type.
    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr = rewriter.getArrayAttr(
        op.getInBoundsAttr().getValue().take_back(reducedShapeRank));
    Value newRead = vector::TransferReadOp::create(
        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        /*...reduced map, padding, mask, newInBoundsAttr elided...*/);
    return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
                                       newRead);
  }
};
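
// Hedged illustration (hypothetical types, not from this file) of the rewrite
// performed by TransferOpReduceRank: one leading broadcast dimension is peeled
// off the read and reinstated with a vector.broadcast.
//
//   %v = vector.transfer_read %m[%i0, %i1, %i2, %i3], %pad
//       {permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
//       : memref<?x?x?x?xf32>, vector<2x4x8x16xf32>
// becomes:
//   %r = vector.transfer_read %m[%i0, %i1, %i2, %i3], %pad
//       {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>}
//       : memref<?x?x?x?xf32>, vector<4x8x16xf32>
//   %v = vector.broadcast %r : vector<4x8x16xf32> to vector<2x4x8x16xf32>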

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

/// Lower a minor-identity (possibly broadcasting) transfer_read on a
/// unit-stride memref into vector.load / vector.maskedload, plus a
/// vector.broadcast when broadcast dimensions are present.
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    // Preconditions; the match-failure diagnostic strings are elided ("...").
    SmallVector<unsigned> broadcastedDims;
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "...");
    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "...");
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(read, "...");

    // If broadcasting is involved, first load the unbroadcasted vector and
    // broadcast it afterwards.
    SmallVector<int64_t> unbroadcastedVectorShape =
        llvm::to_vector(read.getVectorType().getShape());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector element types only when the loaded vector
    // type matches the element type; otherwise the scalar element types of the
    // memref and the vector must match.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "...");
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "...");

    // Out-of-bounds dimensions are handled by other patterns.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "...");

    // Create the load (masked if the transfer carries a mask).
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors only.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = vector::BroadcastOp::create(
          rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = vector::MaskedLoadOp::create(
          rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = vector::LoadOp::create(rewriter, read.getLoc(),
                                   unbroadcastedVectorType, read.getBase(),
                                   read.getIndices());
    }

    // Insert a broadcast if required.
    if (!broadcastedDims.empty())
      res = vector::BroadcastOp::create(
          rewriter, read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};
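
// Hedged illustration (hypothetical types, not from this file) of the rewrite
// performed by TransferReadToVectorLoadLowering. An in-bounds, minor-identity
// transfer_read becomes a plain load; the masked rank-1 variant becomes a
// masked load whose pass-through value is the broadcast padding:
//
//   %v = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true]}
//       : memref<64x64xf32>, vector<4xf32>
//   -> %v = vector.load %m[%i, %j] : memref<64x64xf32>, vector<4xf32>
//
//   %v = vector.transfer_read %m[%i, %j], %pad, %mask {in_bounds = [true]}
//       : memref<64x64xf32>, vector<4xf32>
//   -> %fill = vector.broadcast %pad : f32 to vector<4xf32>
//      %v = vector.maskedload %m[%i, %j], %mask, %fill
//          : memref<64x64xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>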

/// Lower a minor-identity transfer_write on a unit-stride memref into
/// vector.store / vector.maskedstore.
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(
          write.getLoc(), [=](Diagnostic &diag) {
            diag << "permutation map is not minor identity: " << write;
          });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector element types only when the stored vector
    // type matches the element type; otherwise the scalar element types of the
    // memref and the vector must match.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });

    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors only.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      vector::MaskedStoreOp::create(rewriter, write.getLoc(), write.getBase(),
                                    write.getIndices(), write.getMask(),
                                    write.getVector());
    } else {
      vector::StoreOp::create(rewriter, write.getLoc(), write.getVector(),
                              write.getBase(), write.getIndices());
    }
    // transfer_write on a memref has no result; Value() signals success.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};
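
// Hedged illustration (hypothetical types, not from this file) of the rewrite
// performed by TransferWriteToVectorStoreLowering:
//
//   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true]}
//       : vector<4xf32>, memref<64x64xf32>
//   -> vector.store %v, %m[%i, %j] : memref<64x64xf32>, vector<4xf32>
//
//   vector.transfer_write %v, %m[%i, %j], %mask {in_bounds = [true]}
//       : vector<4xf32>, memref<64x64xf32>
//   -> vector.maskedstore %m[%i, %j], %mask, %v
//          : memref<64x64xf32>, vector<4xi1>, vector<4xf32>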

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
}
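
// ---------------------------------------------------------------------------
// Illustrative usage sketch, not part of this file: how a pass might combine
// the two populate functions above with MLIR's greedy rewrite driver. The
// function name `lowerVectorTransfers` is hypothetical, and the header paths
// and driver entry point are assumptions based on the upstream MLIR APIs.
// ---------------------------------------------------------------------------
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerVectorTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Simplify permutation maps first, then rewrite qualifying transfers (capped
  // here at rank 1) into vector.load / vector.store.
  mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  mlir::vector::populateVectorTransferLoweringPatterns(
      patterns, /*maxTransferRank=*/1);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}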