24#include "llvm/ADT/STLExtras.h"
27#define DEBUG_TYPE "vector-shape-cast-lowering"
52 for (
int dim : llvm::reverse(llvm::seq<int>(0,
rhs.size()))) {
54 assert(
rhs[dim] < dimBase &&
"rhs not in base");
61 lhs = incremented / dimBase;
62 rhs[dim] = incremented % dimBase;
123class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
127 static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
128 PatternRewriter &rewriter) {
130 const Location loc = shapeCast.getLoc();
131 const VectorType sourceType = shapeCast.getSourceVectorType();
132 const VectorType resultType = shapeCast.getResultVectorType();
134 const int64_t sourceRank = sourceType.getRank();
135 const int64_t resultRank = resultType.getRank();
136 const int64_t delta = sourceRank - resultRank;
137 const int64_t sourceLeading = delta > 0 ? delta : 0;
138 const int64_t resultLeading = delta > 0 ? 0 : -delta;
140 const Value source = shapeCast.getSource();
141 const Value poison = ub::PoisonOp::create(rewriter, loc, resultType);
142 const Value extracted = vector::ExtractOp::create(
143 rewriter, loc, source, SmallVector<int64_t>(sourceLeading, 0));
145 vector::InsertOp::create(rewriter, loc, extracted, poison,
146 SmallVector<int64_t>(resultLeading, 0));
156 static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
159 PatternRewriter &rewriter) {
161 const Location loc = shapeCast.getLoc();
163 const Value source = shapeCast.getSource();
164 const ArrayRef<int64_t> sourceShape =
165 shapeCast.getSourceVectorType().getShape();
167 const VectorType resultType = shapeCast.getResultVectorType();
168 const ArrayRef<int64_t> resultShape = resultType.getShape();
170 const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim));
171 SmallVector<int64_t> extractIndex(sourceDim, 0);
172 SmallVector<int64_t> insertIndex(resultDim, 0);
173 Value
result = ub::PoisonOp::create(rewriter, loc, resultType);
175 for (
int i = 0; i < nSlices; ++i) {
177 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
179 result = vector::InsertOp::create(rewriter, loc, extracted,
result,
182 inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
183 inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
192 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
193 PatternRewriter &rewriter)
const override {
194 Location loc = op.getLoc();
195 VectorType sourceType = op.getSourceVectorType();
196 VectorType resultType = op.getResultVectorType();
198 if (sourceType.isScalable() || resultType.isScalable())
201 "shape_cast where vectors are scalable not handled by this pattern");
203 const ArrayRef<int64_t> sourceShape = sourceType.getShape();
204 const ArrayRef<int64_t> resultShape = resultType.getShape();
205 const int64_t sourceRank = sourceType.getRank();
206 const int64_t resultRank = resultType.getRank();
207 const int64_t numElms = sourceType.getNumElements();
208 const Value source = op.getSource();
220 int64_t sourceSuffixStartDim = sourceRank - 1;
221 int64_t resultSuffixStartDim = resultRank - 1;
222 while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
223 (sourceType.getDimSize(sourceSuffixStartDim) ==
224 resultType.getDimSize(resultSuffixStartDim))) {
225 --sourceSuffixStartDim;
226 --resultSuffixStartDim;
232 if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
233 return leadingOnesLowering(op, rewriter);
235 const int64_t sourceSuffixStartDimSize =
236 sourceType.getDimSize(sourceSuffixStartDim);
237 const int64_t resultSuffixStartDimSize =
238 resultType.getDimSize(resultSuffixStartDim);
239 const int64_t greatestCommonDivisor =
240 std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
241 const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
242 const size_t extractPeriod =
243 sourceSuffixStartDimSize / greatestCommonDivisor;
244 const size_t insertPeriod =
245 resultSuffixStartDimSize / greatestCommonDivisor;
247 SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
249 atomicShape[0] = greatestCommonDivisor;
251 const int64_t numAtomicElms = std::accumulate(
252 atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
253 const size_t nAtomicSlices = numElms / numAtomicElms;
259 if (greatestCommonDivisor == 1)
260 return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
261 resultSuffixStartDim + 1, rewriter);
264 const ArrayRef<int64_t> insertStridedShape =
265 resultShape.drop_front(resultSuffixStartDim);
266 const VectorType insertStridedType =
267 VectorType::get(insertStridedShape, resultType.getElementType());
269 SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
270 SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
271 SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
272 SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
273 const SmallVector<int64_t> sizes(stridedSliceRank, 1);
275 Value extracted = {};
276 Value extractedStrided = {};
277 Value insertedSlice = {};
278 Value
result = ub::PoisonOp::create(rewriter, loc, resultType);
279 const Value partResult =
280 ub::PoisonOp::create(rewriter, loc, insertStridedType);
282 for (
size_t i = 0; i < nAtomicSlices; ++i) {
284 const size_t extractStridedPhase = i % extractPeriod;
285 const size_t insertStridedPhase = i % insertPeriod;
288 if (extractStridedPhase == 0) {
290 vector::ExtractOp::create(rewriter, loc, source, extractIndex);
291 inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
296 extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
297 extractedStrided = vector::ExtractStridedSliceOp::create(
298 rewriter, loc, extracted, extractOffsets, atomicShape, sizes);
301 if (insertStridedPhase == 0) {
302 insertedSlice = partResult;
304 insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
305 insertedSlice = vector::InsertStridedSliceOp::create(
306 rewriter, loc, extractedStrided, insertedSlice, insertOffsets, sizes);
309 if (insertStridedPhase + 1 == insertPeriod) {
310 result = vector::InsertOp::create(rewriter, loc, insertedSlice,
result,
312 inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
354class ScalableShapeCastOpRewritePattern
359 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
360 PatternRewriter &rewriter)
const override {
362 Location loc = op.getLoc();
363 auto sourceVectorType = op.getSourceVectorType();
364 auto resultVectorType = op.getResultVectorType();
365 auto srcRank = sourceVectorType.getRank();
366 auto resRank = resultVectorType.getRank();
373 if (!isTrailingDimScalable(sourceVectorType) ||
374 !isTrailingDimScalable(resultVectorType)) {
376 op,
"trailing dims are not scalable, not handled by this pattern");
382 auto minSourceTrailingSize = sourceVectorType.getShape().back();
383 auto minResultTrailingSize = resultVectorType.getShape().back();
384 auto minExtractionSize =
385 std::min(minSourceTrailingSize, minResultTrailingSize);
386 int64_t minNumElts = 1;
387 for (
auto size : sourceVectorType.getShape())
393 auto extractionVectorType = VectorType::get(
394 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
396 Value
result = ub::PoisonOp::create(rewriter, loc, resultVectorType);
397 SmallVector<int64_t> srcIdx(srcRank, 0);
398 SmallVector<int64_t> resIdx(resRank, 0);
402 Value currentResultScalableVector;
403 Value currentSourceScalableVector;
404 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
406 if (!currentSourceScalableVector) {
408 currentSourceScalableVector =
409 vector::ExtractOp::create(rewriter, loc, op.getSource(),
410 llvm::ArrayRef(srcIdx).drop_back());
412 currentSourceScalableVector = op.getSource();
415 Value sourceSubVector = currentSourceScalableVector;
416 if (minExtractionSize < minSourceTrailingSize) {
417 sourceSubVector = vector::ScalableExtractOp::create(
418 rewriter, loc, extractionVectorType, sourceSubVector,
423 if (!currentResultScalableVector) {
424 if (minExtractionSize == minResultTrailingSize) {
425 currentResultScalableVector = sourceSubVector;
426 }
else if (resRank != 1) {
427 currentResultScalableVector = vector::ExtractOp::create(
428 rewriter, loc,
result, llvm::ArrayRef(resIdx).drop_back());
430 currentResultScalableVector =
result;
433 if (minExtractionSize < minResultTrailingSize) {
434 currentResultScalableVector = vector::ScalableInsertOp::create(
435 rewriter, loc, sourceSubVector, currentResultScalableVector,
440 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
441 currentResultScalableVector !=
result) {
444 result = vector::InsertOp::create(rewriter, loc,
445 currentResultScalableVector,
result,
446 llvm::ArrayRef(resIdx).drop_back());
447 currentResultScalableVector = {};
449 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
451 currentSourceScalableVector = {};
456 inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
457 inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
464 static bool isTrailingDimScalable(VectorType type) {
465 return type.getRank() >= 1 && type.getScalableDims().back() &&
466 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
474 patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
static void inplaceAdd(int64_t lhs, ArrayRef< int64_t > base, MutableArrayRef< int64_t > rhs)
Perform the inplace update rhs <- lhs + rhs.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns: ShapeCastOpRewritePattern and ScalableShapeCastOpRewritePattern.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.