24 #include "llvm/ADT/STLExtras.h"
27 #define DEBUG_TYPE "vector-shape-cast-lowering"
// Fragment of what appears to be the body of inplaceAdd (the signature is
// outside this excerpt -- confirm). Visits the positions of `rhs` from the
// last one to the first, using `lhs` as the carry into the current position.
52 for (
int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
53 int64_t dimBase = base[dim];
// Precondition: the digit currently stored at this position is already
// smaller than the base of this position.
54 assert(rhs[dim] < dimBase &&
"rhs not in base");
56 int64_t incremented = rhs[dim] + lhs;
// The quotient is carried to the next visited (earlier) position; the
// remainder is the new digit written back at this position.
61 lhs = incremented / dimBase;
62 rhs[dim] = incremented % dimBase;
123 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
// Lowering that matchAndRewrite selects when its suffix comparison runs off
// the front of the source or the result shape. Going by the name, the extra
// leading dimensions are presumably all of size 1 -- confirm. Only part of the
// body is visible in this excerpt.
127 static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
130 const Location loc = shapeCast.getLoc();
131 const VectorType sourceType = shapeCast.getSourceVectorType();
132 const VectorType resultType = shapeCast.getResultVectorType();
134 const int64_t sourceRank = sourceType.getRank();
135 const int64_t resultRank = resultType.getRank();
// delta > 0: the source has `delta` more dimensions than the result.
// delta < 0: the result has `-delta` more dimensions than the source.
// At most one of sourceLeading / resultLeading is non-zero.
136 const int64_t delta = sourceRank - resultRank;
137 const int64_t sourceLeading = delta > 0 ? delta : 0;
138 const int64_t resultLeading = delta > 0 ? 0 : -delta;
140 const Value source = shapeCast.getSource();
// A poison value of the result type is the destination of the insert below.
141 const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
// One vector::ExtractOp followed by one vector::InsertOp of the extracted
// value into the poison value (the operand lists are cut off in this excerpt).
142 const Value extracted = vector::ExtractOp::create(
145 vector::InsertOp::create(rewriter, loc, extracted, poison,
// Lowering built only from vector::ExtractOp / vector::InsertOp pairs.
// matchAndRewrite calls it when the gcd of the first mismatching dimension
// sizes is 1. Only part of the body is visible in this excerpt.
156 static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
161 const Location loc = shapeCast.getLoc();
163 const Value source = shapeCast.getSource();
165 shapeCast.getSourceVectorType().getShape();
167 const VectorType resultType = shapeCast.getResultVectorType();
// Number of extract/insert pairs: the product of the first `sourceDim`
// dimension sizes of the source shape.
170 const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim));
// The result is accumulated starting from a poison value.
173 Value result = ub::PoisonOp::create(rewriter, loc, resultType);
// NOTE(review): the counter is `int` while nSlices is `int64_t`; consider
// matching the types.
175 for (
int i = 0; i < nSlices; ++i) {
177 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
179 result = vector::InsertOp::create(rewriter, loc, extracted, result,
// Advance both multi-dimensional indices by one position, each within its own
// leading shape.
182 inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
183 inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
// Entry point of the fixed-size shape_cast lowering. Dispatches to
// leadingOnesLowering, noStridedSliceLowering, or the strided-slice loop at
// the end of this function. Several lines of the body are missing from this
// excerpt.
192 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
195 VectorType sourceType = op.getSourceVectorType();
196 VectorType resultType = op.getResultVectorType();
// Scalable vectors are not handled here (see the diagnostic string below).
198 if (sourceType.isScalable() || resultType.isScalable())
201 "shape_cast where vectors are scalable not handled by this pattern");
205 const int64_t sourceRank = sourceType.getRank();
206 const int64_t resultRank = resultType.getRank();
207 const int64_t numElms = sourceType.getNumElements();
208 const Value source = op.getSource();
// Walk both shapes from the last dimension towards the first while the
// dimension sizes are equal. On exit each index is either -1 or the position
// of the first (from the back) dimension whose size differs.
220 int64_t sourceSuffixStartDim = sourceRank - 1;
221 int64_t resultSuffixStartDim = resultRank - 1;
222 while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
223 (sourceType.getDimSize(sourceSuffixStartDim) ==
224 resultType.getDimSize(resultSuffixStartDim))) {
225 --sourceSuffixStartDim;
226 --resultSuffixStartDim;
// One of the shapes was fully consumed by the walk above.
232 if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
233 return leadingOnesLowering(op, rewriter);
// Sizes of the first mismatching dimensions and their greatest common divisor.
235 const int64_t sourceSuffixStartDimSize =
236 sourceType.getDimSize(sourceSuffixStartDim);
237 const int64_t resultSuffixStartDimSize =
238 resultType.getDimSize(resultSuffixStartDim);
239 const int64_t greatestCommonDivisor =
240 std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
241 const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
// Number of gcd-sized pieces per mismatching source dimension and per
// mismatching result dimension, respectively.
242 const size_t extractPeriod =
243 sourceSuffixStartDimSize / greatestCommonDivisor;
244 const size_t insertPeriod =
245 resultSuffixStartDimSize / greatestCommonDivisor;
249 atomicShape[0] = greatestCommonDivisor;
// NOTE(review): std::accumulate deduces its accumulator type from the init
// argument. The literal `1` is an int, so the running product is held in an
// int even though the functor is std::multiplies<int64_t>. Consider
// `int64_t{1}` or llvm::product_of (already used in noStridedSliceLowering).
251 const int64_t numAtomicElms = std::accumulate(
252 atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
253 const size_t nAtomicSlices = numElms / numAtomicElms;
// With a gcd of 1 the extract/insert-only lowering is used, starting one
// dimension after the mismatching ones.
259 if (greatestCommonDivisor == 1)
260 return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
261 resultSuffixStartDim + 1, rewriter);
265 resultShape.drop_front(resultSuffixStartDim);
266 const VectorType insertStridedType =
// Loop-carried values. `result` and `partResult` start out as poison values.
275 Value extracted = {};
276 Value extractedStrided = {};
277 Value insertedSlice = {};
278 Value result = ub::PoisonOp::create(rewriter, loc, resultType);
279 const Value partResult =
280 ub::PoisonOp::create(rewriter, loc, insertStridedType);
282 for (
size_t i = 0; i < nAtomicSlices; ++i) {
// Position of this iteration within the current extract / insert period.
284 const size_t extractStridedPhase = i % extractPeriod;
285 const size_t insertStridedPhase = i % insertPeriod;
// At the start of an extract period a new vector::ExtractOp is created from
// the source and the extract index is advanced.
288 if (extractStridedPhase == 0) {
290 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
291 inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
// Take the piece at offset phase * gcd along the leading dimension.
296 extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
297 extractedStrided = vector::ExtractStridedSliceOp::create(
298 rewriter, loc, extracted, extractOffsets, atomicShape, sizes);
// At the start of an insert period the partial result restarts from poison.
301 if (insertStridedPhase == 0) {
302 insertedSlice = partResult;
304 insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
305 insertedSlice = vector::InsertStridedSliceOp::create(
306 rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);
// At the end of an insert period the partial result is inserted into `result`
// and the insert index is advanced.
309 if (insertStridedPhase + 1 == insertPeriod) {
310 result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
312 inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
354 class ScalableShapeCastOpRewritePattern
// Entry point of the shape_cast lowering for vectors whose trailing dimension
// is scalable. Several lines of the body (including some branch conditions)
// are missing from this excerpt.
359 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
363 auto sourceVectorType = op.getSourceVectorType();
364 auto resultVectorType = op.getResultVectorType();
365 auto srcRank = sourceVectorType.getRank();
366 auto resRank = resultVectorType.getRank();
// Both the source and the result type must satisfy isTrailingDimScalable.
373 if (!isTrailingDimScalable(sourceVectorType) ||
374 !isTrailingDimScalable(resultVectorType)) {
376 op,
"trailing dims are not scalable, not handled by this pattern");
// The step of the main loop is the smaller of the two trailing dimension
// sizes.
382 auto minSourceTrailingSize = sourceVectorType.getShape().back();
383 auto minResultTrailingSize = resultVectorType.getShape().back();
384 auto minExtractionSize =
385 std::min(minSourceTrailingSize, minResultTrailingSize);
386 int64_t minNumElts = 1;
387 for (
auto size : sourceVectorType.getShape())
394 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
// The result is accumulated starting from a poison value.
396 Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
// Both start out null; a null value means "nothing cached yet" and triggers
// the (re)initialisation branches inside the loop.
402 Value currentResultScalableVector;
403 Value currentSourceScalableVector;
404 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
// (Re)load the current source vector: either a vector::ExtractOp from the
// op's source or the source itself. The condition choosing between the two is
// not visible here -- presumably srcRank != 1, mirroring the result side.
406 if (!currentSourceScalableVector) {
408 currentSourceScalableVector =
409 vector::ExtractOp::create(rewriter, loc, op.getSource(),
412 currentSourceScalableVector = op.getSource();
// When the step is smaller than the source trailing size, only a sub-vector
// of the current source vector is taken.
415 Value sourceSubVector = currentSourceScalableVector;
416 if (minExtractionSize < minSourceTrailingSize) {
417 sourceSubVector = vector::ScalableExtractOp::create(
418 rewriter, loc, extractionVectorType, sourceSubVector,
// (Re)initialise the current result vector: the sub-vector itself when it
// already has the result trailing size, otherwise a vector::ExtractOp
// (resRank != 1) or `result`.
423 if (!currentResultScalableVector) {
424 if (minExtractionSize == minResultTrailingSize) {
425 currentResultScalableVector = sourceSubVector;
426 }
else if (resRank != 1) {
427 currentResultScalableVector = vector::ExtractOp::create(
430 currentResultScalableVector = result;
// When the step is smaller than the result trailing size, the sub-vector is
// inserted into the current result vector.
433 if (minExtractionSize < minResultTrailingSize) {
434 currentResultScalableVector = vector::ScalableInsertOp::create(
435 rewriter, loc, sourceSubVector, currentResultScalableVector,
// Once this step reaches the end of the result trailing dimension, and the
// current result vector is not `result` itself, insert it into `result` and
// clear it.
440 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
441 currentResultScalableVector != result) {
444 result = vector::InsertOp::create(rewriter, loc,
445 currentResultScalableVector, result,
447 currentResultScalableVector = {};
// Once this step reaches the end of the source trailing dimension, clear the
// current source vector so it is reloaded on the next iteration.
449 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
451 currentSourceScalableVector = {};
// Advance both multi-dimensional indices by the step.
456 inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
457 inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
// Returns true iff `type` has rank >= 1, its last dimension is scalable, and
// none of the preceding dimensions is scalable.
464 static bool isTrailingDimScalable(VectorType type) {
465 return type.getRank() >= 1 && type.getScalableDims().back() &&
466 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
474 patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
static void inplaceAdd(int64_t lhs, ArrayRef< int64_t > base, MutableArrayRef< int64_t > rhs)
Perform the inplace update rhs <- lhs + rhs.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...