26 #define DEBUG_TYPE "vector-shape-cast-lowering"
// Fragment of inplaceAdd: performs the update rhs <- rhs + lhs, treating
// `rhs` as a mixed-radix number whose per-digit bases are `base`. Digits are
// visited from least significant (last dim) to most significant, with `lhs`
// acting as the running carry.
// NOTE(review): doxygen-extracted listing — the numeric prefixes (51, 52, ...)
// are the original source line numbers fused into the text, and some interior
// lines are missing from this view.
51 for (
int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
// Radix (base) of the current digit.
52 int64_t dimBase = base[dim];
// Precondition: every digit of rhs must already lie in [0, dimBase).
53 assert(rhs[dim] < dimBase &&
"rhs not in base");
55 int64_t incremented = rhs[dim] + lhs;
// Carry propagated into the next (more significant) digit...
60 lhs = incremented / dimBase;
// ...and the remainder becomes this digit's new value.
61 rhs[dim] = incremented % dimBase;
122 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
// Lowers a shape_cast whose source and result shapes share their entire
// common suffix and differ only in leading dimensions: a single extract of
// the shared suffix from the source is inserted into a poison vector of the
// result type.
// NOTE(review): fragmentary doxygen listing — the `rewriter` parameter line,
// the extract/insert index arguments, and the final replaceOp/return are not
// visible here; numeric prefixes are original source line numbers.
126 static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
129 const Location loc = shapeCast.getLoc();
130 const VectorType sourceType = shapeCast.getSourceVectorType();
131 const VectorType resultType = shapeCast.getResultVectorType();
133 const int64_t sourceRank = sourceType.getRank();
134 const int64_t resultRank = resultType.getRank();
// delta > 0: the source has extra leading dims; delta < 0: the result does.
135 const int64_t delta = sourceRank - resultRank;
// Number of leading dims to skip on each side (one of the two is 0).
136 const int64_t sourceLeading = delta > 0 ? delta : 0;
137 const int64_t resultLeading = delta > 0 ? 0 : -delta;
139 const Value source = shapeCast.getSource();
// Poison-initialized vector of the result type; the extracted suffix is
// inserted into it below.
140 const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
141 const Value extracted = vector::ExtractOp::create(
144 vector::InsertOp::create(rewriter, loc, extracted, poison,
// Lowering used when no strided-slice ops are needed (the caller dispatches
// here when the GCD of the differing suffix dims is 1): each slice is moved
// with a plain vector.extract / vector.insert pair.
// NOTE(review): fragmentary listing — the sourceDim/resultDim parameters,
// resultShape initialization, extractIndex/insertIndex declarations, and the
// final replaceOp are not visible in this view.
155 static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
160 const Location loc = shapeCast.getLoc();
162 const Value source = shapeCast.getSource();
164 shapeCast.getSourceVectorType().getShape();
166 const VectorType resultType = shapeCast.getResultVectorType();
// Number of slices to copy: product of the leading source dims before
// sourceDim.
169 const int64_t nSlices =
170 std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
171 std::multiplies<int64_t>());
// Start from a poison vector and fill it slice by slice.
175 Value result = ub::PoisonOp::create(rewriter, loc, resultType);
177 for (
int i = 0; i < nSlices; ++i) {
179 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
181 result = vector::InsertOp::create(rewriter, loc, extracted, result,
// Advance both multi-dimensional indices by one slice (mixed-radix
// increment over the leading dims only).
184 inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
185 inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
// General (fixed-size) shape_cast lowering. Finds the longest common shape
// suffix of source and result, then moves data in "atomic" slices whose
// leading dim is gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize),
// using extract/insert plus extract_strided_slice/insert_strided_slice.
// NOTE(review): fragmentary doxygen listing — several declarations
// (extractIndex, insertIndex, sourceShape, atomicShape, sizes, offsets, loc)
// and the final replaceOp are not visible; numeric prefixes are original
// source line numbers.
194 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
197 VectorType sourceType = op.getSourceVectorType();
198 VectorType resultType = op.getResultVectorType();
// Scalable vectors are handled by ScalableShapeCastOpRewritePattern instead.
200 if (sourceType.isScalable() || resultType.isScalable())
203 "shape_cast where vectors are scalable not handled by this pattern");
207 const int64_t sourceRank = sourceType.getRank();
208 const int64_t resultRank = resultType.getRank();
209 const int64_t numElms = sourceType.getNumElements();
210 const Value source = op.getSource();
// Walk backwards from the trailing dims while the sizes agree, to find
// where the shared suffix of the two shapes begins.
222 int64_t sourceSuffixStartDim = sourceRank - 1;
223 int64_t resultSuffixStartDim = resultRank - 1;
224 while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
225 (sourceType.getDimSize(sourceSuffixStartDim) ==
226 resultType.getDimSize(resultSuffixStartDim))) {
227 --sourceSuffixStartDim;
228 --resultSuffixStartDim;
// One shape is a suffix of the other: only leading dims differ, so the
// cheap leading-ones lowering applies.
234 if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
235 return leadingOnesLowering(op, rewriter);
// First dims (from the suffix start) that disagree between the shapes.
237 const int64_t sourceSuffixStartDimSize =
238 sourceType.getDimSize(sourceSuffixStartDim);
239 const int64_t resultSuffixStartDimSize =
240 resultType.getDimSize(resultSuffixStartDim);
// The atomic slice's leading dim: largest size dividing both.
241 const int64_t greatestCommonDivisor =
242 std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
243 const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
// How many atomic slices fit in one source (resp. result) slice; used as
// the modulus for the phase counters in the loop below.
244 const size_t extractPeriod =
245 sourceSuffixStartDimSize / greatestCommonDivisor;
246 const size_t insertPeriod =
247 resultSuffixStartDimSize / greatestCommonDivisor;
251 atomicShape[0] = greatestCommonDivisor;
253 const int64_t numAtomicElms = std::accumulate(
254 atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
// Total atomic slices to move; numElms is the same for source and result.
255 const size_t nAtomicSlices = numElms / numAtomicElms;
// gcd == 1 means strided slices degenerate to single elements; use the
// simpler extract/insert-only lowering.
261 if (greatestCommonDivisor == 1)
262 return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
263 resultSuffixStartDim + 1, rewriter);
267 resultShape.drop_front(resultSuffixStartDim);
268 const VectorType insertStridedType =
// Loop-carried values: the current source slice, the atomic piece cut from
// it, and the partially assembled result slice.
277 Value extracted = {};
278 Value extractedStrided = {};
279 Value insertedSlice = {};
280 Value result = ub::PoisonOp::create(rewriter, loc, resultType);
// Fresh poison seed for each partially-built result slice.
281 const Value partResult =
282 ub::PoisonOp::create(rewriter, loc, insertStridedType);
284 for (
size_t i = 0; i < nAtomicSlices; ++i) {
// Position of this atomic slice within the current source / result slice.
286 const size_t extractStridedPhase = i % extractPeriod;
287 const size_t insertStridedPhase = i % insertPeriod;
// Phase 0: a fresh source slice is needed; extract it and bump the index.
290 if (extractStridedPhase == 0) {
292 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
293 inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
// Cut the atomic piece out of the current source slice at the phase offset.
298 extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
299 extractedStrided = vector::ExtractStridedSliceOp::create(
300 rewriter, loc, extracted, extractOffsets, atomicShape, sizes);
// Phase 0 on the insert side: start assembling a new result slice.
303 if (insertStridedPhase == 0) {
304 insertedSlice = partResult;
306 insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
307 insertedSlice = vector::InsertStridedSliceOp::create(
308 rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);
// Result slice complete: insert it into the full result and bump the index.
311 if (insertStridedPhase + 1 == insertPeriod) {
312 result = vector::InsertOp::create(rewriter, loc, insertedSlice, result,
314 inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
356 class ScalableShapeCastOpRewritePattern
// Lowers shape_cast between vectors whose trailing dim is the only scalable
// dim on both sides. Data is moved in chunks of minExtractionSize (the
// smaller of the two min trailing sizes) using scalable extract/insert ops
// for sub-trailing-dim pieces and regular extract/insert for whole rows.
// NOTE(review): fragmentary doxygen listing — loc, srcIdx/resIdx
// declarations, extractionVectorType's full construction, several index
// arguments, and the final replaceOp are not visible in this view.
361 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
365 auto sourceVectorType = op.getSourceVectorType();
366 auto resultVectorType = op.getResultVectorType();
367 auto srcRank = sourceVectorType.getRank();
368 auto resRank = resultVectorType.getRank();
// Both sides must have their (single) scalable dim trailing; anything else
// is left for other patterns.
375 if (!isTrailingDimScalable(sourceVectorType) ||
376 !isTrailingDimScalable(resultVectorType)) {
378 op,
"trailing dims are not scalable, not handled by this pattern");
// "min" sizes: the known-at-compile-time multiple of the scalable dims.
384 auto minSourceTrailingSize = sourceVectorType.getShape().back();
385 auto minResultTrailingSize = resultVectorType.getShape().back();
// Chunk size per transfer: the smaller of the two trailing sizes.
386 auto minExtractionSize =
387 std::min(minSourceTrailingSize, minResultTrailingSize);
388 int64_t minNumElts = 1;
389 for (
auto size : sourceVectorType.getShape())
// 1-D scalable vector type used for each extracted chunk.
396 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
398 Value result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
// Loop-carried 1-D scalable rows currently being read from / written to;
// empty Value means a fresh row must be fetched / started.
404 Value currentResultScalableVector;
405 Value currentSourceScalableVector;
406 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
// Fetch the next source row (rank-1 sources are the row itself).
408 if (!currentSourceScalableVector) {
410 currentSourceScalableVector =
411 vector::ExtractOp::create(rewriter, loc, op.getSource(),
414 currentSourceScalableVector = op.getSource();
417 Value sourceSubVector = currentSourceScalableVector;
// Source row is wider than a chunk: carve out the chunk scalably.
418 if (minExtractionSize < minSourceTrailingSize) {
419 sourceSubVector = vector::ScalableExtractOp::create(
420 rewriter, loc, extractionVectorType, sourceSubVector,
// Start a new result row if none is in flight.
425 if (!currentResultScalableVector) {
426 if (minExtractionSize == minResultTrailingSize) {
// Chunk is exactly one result row — use it directly.
427 currentResultScalableVector = sourceSubVector;
428 }
else if (resRank != 1) {
429 currentResultScalableVector = vector::ExtractOp::create(
432 currentResultScalableVector = result;
// Result row is wider than a chunk: insert the chunk scalably.
435 if (minExtractionSize < minResultTrailingSize) {
436 currentResultScalableVector = vector::ScalableInsertOp::create(
437 rewriter, loc, sourceSubVector, currentResultScalableVector,
// Result row filled: commit it into the full result and reset.
442 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
443 currentResultScalableVector != result) {
446 result = vector::InsertOp::create(rewriter, loc,
447 currentResultScalableVector, result,
449 currentResultScalableVector = {};
// Source row exhausted: force a fresh fetch next iteration.
451 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
453 currentSourceScalableVector = {};
// Advance both multi-dim positions by one chunk (mixed-radix add).
458 inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
459 inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
// True iff `type` has rank >= 1, its last dim is scalable, and no other dim
// is scalable — i.e. the trailing dim is the unique scalable dim.
466 static bool isTrailingDimScalable(VectorType type) {
467 return type.getRank() >= 1 && type.getScalableDims().back() &&
468 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
476 patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
static void inplaceAdd(int64_t lhs, ArrayRef< int64_t > base, MutableArrayRef< int64_t > rhs)
Perform the inplace update rhs <- lhs + rhs.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most benefit).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, with a callback that populates a diagnostic with the reason for the failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...