26 #define DEBUG_TYPE "vector-shape-cast-lowering"
// Mixed-radix increment: fold the scalar addend `lhs` into the multi-digit
// counter `rhs`, where digit `dim` has radix `base[dim]`. Digits are walked
// from the innermost (last) dimension outwards, propagating the carry in
// `lhs`.
51 for (
int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
52 int64_t dimBase = base[dim];
// Precondition: each digit must already be a valid index for its radix.
53 assert(rhs[dim] < dimBase &&
"rhs not in base");
55 int64_t incremented = rhs[dim] + lhs;
// `lhs` becomes the carry into the next-outer digit; the remainder stays
// in this digit. NOTE(review): original lines 56-59 are elided in this
// excerpt.
60 lhs = incremented / dimBase;
61 rhs[dim] = incremented % dimBase;
// Rewrites vector.shape_cast ops on fixed-size vectors; the static helpers
// below implement the individual lowering strategies selected by
// matchAndRewrite.
122 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
// Lowering for the case where the shared shape suffix covers one side's
// entire rank, i.e. source and result differ only in leading dimensions
// (presumably unit dims, per the helper's name — the elided lines would
// confirm): a single extract of the common payload from the source, then
// a single insert into a poison result. NOTE(review): parameter lists and
// the op-creation arguments (original lines 127-128, 142, 144+) are elided
// in this excerpt.
126 static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
129 const Location loc = shapeCast.getLoc();
130 const VectorType sourceType = shapeCast.getSourceVectorType();
131 const VectorType resultType = shapeCast.getResultVectorType();
133 const int64_t sourceRank = sourceType.getRank();
134 const int64_t resultRank = resultType.getRank();
// delta > 0: the source has extra leading dims; delta < 0: the result does.
135 const int64_t delta = sourceRank - resultRank;
136 const int64_t sourceLeading = delta > 0 ? delta : 0;
137 const int64_t resultLeading = delta > 0 ? 0 : -delta;
139 const Value source = shapeCast.getSource();
// Start from a poison value of the result type; the insert below supplies
// the actual payload.
140 const Value poison = rewriter.
create<ub::PoisonOp>(loc, resultType);
141 const Value extracted = rewriter.
create<vector::ExtractOp>(
143 const Value result = rewriter.
create<vector::InsertOp>(
// Lowering that needs no strided slices: the shape_cast is performed by
// copying `nSlices` complete sub-vectors one at a time with plain
// extract/insert pairs. NOTE(review): several parameter and local
// declarations (original lines 155-158, 160, 162, 164, 166-167, 171-173,
// 175, 177, 179, 181-182) are elided in this excerpt.
154 static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
159 const Location loc = shapeCast.getLoc();
161 const Value source = shapeCast.getSource();
163 shapeCast.getSourceVectorType().getShape();
165 const VectorType resultType = shapeCast.getResultVectorType();
// Number of slices to move: the product of the leading source dims in
// front of the copied suffix.
168 const int64_t nSlices =
169 std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
170 std::multiplies<int64_t>());
// Accumulate the result starting from a poison vector of the result type.
174 Value result = rewriter.
create<ub::PoisonOp>(loc, resultType);
176 for (
int i = 0; i < nSlices; ++i) {
178 rewriter.
create<vector::ExtractOp>(loc, source, extractIndex);
180 result = rewriter.
create<vector::InsertOp>(loc, extracted, result,
// Advance both multi-dimensional indices by one slice (mixed-radix +1).
183 inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
184 inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
// Generic lowering of a fixed-size vector.shape_cast to chains of
// extract / insert / (extract|insert)_strided_slice ops. Strategy: find the
// longest common shape suffix, then pick the cheapest of three lowerings.
// NOTE(review): the original's explanatory comment blocks and several
// declarations (e.g. original lines 194-195, 198, 200-201, 203-205,
// 210-220, 228-232, 235, 247-251, 263-275, 282, 284, 287-290, 293-296,
// 300-301, 304, 308-309, 312, 314+) are elided in this excerpt.
193 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
196 VectorType sourceType = op.getSourceVectorType();
197 VectorType resultType = op.getResultVectorType();
// Scalable vectors are handled by ScalableShapeCastOpRewritePattern below.
199 if (sourceType.isScalable() || resultType.isScalable())
202 "shape_cast where vectors are scalable not handled by this pattern");
206 const int64_t sourceRank = sourceType.getRank();
207 const int64_t resultRank = resultType.getRank();
208 const int64_t numElms = sourceType.getNumElements();
209 const Value source = op.getSource();
// Walk both shapes backwards to find the first dimension (from the end)
// where source and result sizes differ; everything behind it is a shared
// suffix.
221 int64_t sourceSuffixStartDim = sourceRank - 1;
222 int64_t resultSuffixStartDim = resultRank - 1;
223 while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
224 (sourceType.getDimSize(sourceSuffixStartDim) ==
225 resultType.getDimSize(resultSuffixStartDim))) {
226 --sourceSuffixStartDim;
227 --resultSuffixStartDim;
// One shape is entirely a suffix of the other: only leading dims differ,
// which the dedicated helper handles with a single extract + insert.
233 if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
234 return leadingOnesLowering(op, rewriter);
// Sizes of the first differing dimensions; their gcd is the width of the
// largest contiguous "atomic" slice that can be moved as one unit.
236 const int64_t sourceSuffixStartDimSize =
237 sourceType.getDimSize(sourceSuffixStartDim);
238 const int64_t resultSuffixStartDimSize =
239 resultType.getDimSize(resultSuffixStartDim);
240 const int64_t greatestCommonDivisor =
241 std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
242 const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
// How many atomic slices fit inside one extracted (resp. inserted) slice;
// these periods drive the phase logic in the loop below.
243 const size_t extractPeriod =
244 sourceSuffixStartDimSize / greatestCommonDivisor;
245 const size_t insertPeriod =
246 resultSuffixStartDimSize / greatestCommonDivisor;
250 atomicShape[0] = greatestCommonDivisor;
252 const int64_t numAtomicElms = std::accumulate(
253 atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
254 const size_t nAtomicSlices = numElms / numAtomicElms;
// gcd == 1 would make the strided slices degenerate; fall back to plain
// whole-slice extract/insert pairs instead.
260 if (greatestCommonDivisor == 1)
261 return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
262 resultSuffixStartDim + 1, rewriter);
266 resultShape.drop_front(resultSuffixStartDim);
267 const VectorType insertStridedType =
// Loop-carried state, reset according to the extract/insert phases below.
276 Value extracted = {};
277 Value extractedStrided = {};
278 Value insertedSlice = {};
279 Value result = rewriter.
create<ub::PoisonOp>(loc, resultType);
280 const Value partResult =
281 rewriter.
create<ub::PoisonOp>(loc, insertStridedType);
// Move one atomic slice per iteration; full-slice extracts and result
// inserts happen only once per period.
283 for (
size_t i = 0; i < nAtomicSlices; ++i) {
285 const size_t extractStridedPhase = i % extractPeriod;
286 const size_t insertStridedPhase = i % insertPeriod;
// Phase 0 of extraction: fetch a fresh full slice from the source and
// advance the source index.
289 if (extractStridedPhase == 0) {
291 rewriter.
create<vector::ExtractOp>(loc, source, extractIndex);
292 inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
// Carve the current atomic slice out of the extracted slice at the
// phase-dependent offset.
297 extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
298 extractedStrided = rewriter.
create<vector::ExtractStridedSliceOp>(
299 loc, extracted, extractOffsets, atomicShape, sizes);
// Phase 0 of insertion: start assembling a new partial result slice from
// the poison prototype.
302 if (insertStridedPhase == 0) {
303 insertedSlice = partResult;
305 insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
306 insertedSlice = rewriter.
create<vector::InsertStridedSliceOp>(
307 loc, extractedStrided, insertedSlice, insertOffsets, sizes);
// Final insertion phase: the partial slice is complete — commit it into
// the result and advance the result index.
310 if (insertStridedPhase + 1 == insertPeriod) {
311 result = rewriter.
create<vector::InsertOp>(loc, insertedSlice, result,
313 inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
// Pattern for vector.shape_cast where both operand and result have a
// scalable trailing dimension: data is moved in chunks of
// `minExtractionSize` elements using scalable extract/insert ops.
// NOTE(review): numerous lines of the original (e.g. 361-363, 368-373,
// 376, 378-382, 389-394, 396, 398-402, 406, 408, 410-411, 413-414,
// 419-421, 427-428, 430-431, 435-437, 440-441, 444, 446, 448, 450-453)
// are elided in this excerpt.
355 class ScalableShapeCastOpRewritePattern
360 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
364 auto sourceVectorType = op.getSourceVectorType();
365 auto resultVectorType = op.getResultVectorType();
366 auto srcRank = sourceVectorType.getRank();
367 auto resRank = resultVectorType.getRank();
// Only shapes whose single scalable dim is the trailing one are handled.
374 if (!isTrailingDimScalable(sourceVectorType) ||
375 !isTrailingDimScalable(resultVectorType)) {
377 op,
"trailing dims are not scalable, not handled by this pattern");
// Chunk size: the smaller of the two (minimum) trailing sizes, so each
// chunk fits in both the source and the result trailing dimension.
383 auto minSourceTrailingSize = sourceVectorType.getShape().back();
384 auto minResultTrailingSize = resultVectorType.getShape().back();
385 auto minExtractionSize =
386 std::min(minSourceTrailingSize, minResultTrailingSize);
// Total (minimum) element count, accumulated over the source shape.
387 int64_t minNumElts = 1;
388 for (
auto size : sourceVectorType.getShape())
// 1-D scalable vector type used for each extracted chunk.
395 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
397 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
// Cached 1-D scalable sub-vectors currently being read from / written to;
// empty Value means "needs (re)loading" / "start a fresh one".
403 Value currentResultScalableVector;
404 Value currentSourceScalableVector;
405 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
// (Re)load the source row being consumed; rank-1 sources are used
// directly.
407 if (!currentSourceScalableVector) {
409 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
412 currentSourceScalableVector = op.getSource();
// If the chunk is narrower than the source row, carve it out with a
// scalable extract at the current trailing offset.
415 Value sourceSubVector = currentSourceScalableVector;
416 if (minExtractionSize < minSourceTrailingSize) {
417 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
418 loc, extractionVectorType, sourceSubVector, srcIdx.back());
// Pick the destination row: reuse the chunk itself when it exactly fills
// a result row, otherwise (re)load the row being assembled.
422 if (!currentResultScalableVector) {
423 if (minExtractionSize == minResultTrailingSize) {
424 currentResultScalableVector = sourceSubVector;
425 }
else if (resRank != 1) {
426 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
429 currentResultScalableVector = result;
// If the chunk is narrower than the result row, splice it in with a
// scalable insert at the current trailing offset.
432 if (minExtractionSize < minResultTrailingSize) {
433 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
434 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
// Result row complete: commit it into `result` and reset the cache.
438 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
439 currentResultScalableVector != result) {
442 result = rewriter.
create<vector::InsertOp>(
443 loc, currentResultScalableVector, result,
445 currentResultScalableVector = {};
// Source row exhausted: force a reload on the next iteration.
447 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
449 currentSourceScalableVector = {};
// Advance both multi-dimensional indices by one chunk (mixed-radix add).
454 inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
455 inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
// True iff the type has rank >= 1, its last dimension is scalable, and no
// other dimension is scalable.
462 static bool isTrailingDimScalable(VectorType type) {
463 return type.getRank() >= 1 && type.getScalableDims().back() &&
!llvm::is_contained(type.getScalableDims().drop_back(),
true);
472 patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
static void inplaceAdd(int64_t lhs, ArrayRef< int64_t > base, MutableArrayRef< int64_t > rhs)
Perform the inplace update rhs <- lhs + rhs.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful match possible).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...