25 #define DEBUG_TYPE "vector-shape-cast-lowering"
33 for (
int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
34 assert(indices[dim] < vecType.getDimSize(dim) &&
35 "Indices are out of bound");
37 if (indices[dim] < vecType.getDimSize(dim))
50 class ShapeCastOpNDDownCastRewritePattern
55 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
57 auto sourceVectorType = op.getSourceVectorType();
58 auto resultVectorType = op.getResultVectorType();
59 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
62 int64_t srcRank = sourceVectorType.getRank();
63 int64_t resRank = resultVectorType.getRank();
64 if (srcRank < 2 || resRank != 1)
69 for (int64_t dim = 0; dim < srcRank - 1; ++dim)
70 numElts *= sourceVectorType.getDimSize(dim);
72 auto loc = op.getLoc();
75 int64_t extractSize = sourceVectorType.getShape().back();
76 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
80 for (int64_t i = 0; i < numElts; ++i) {
82 incIdx(srcIdx, sourceVectorType, 1);
83 incIdx(resIdx, resultVectorType, extractSize);
87 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
88 result = rewriter.
create<vector::InsertStridedSliceOp>(
103 class ShapeCastOpNDUpCastRewritePattern
108 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
110 auto sourceVectorType = op.getSourceVectorType();
111 auto resultVectorType = op.getResultVectorType();
112 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
115 int64_t srcRank = sourceVectorType.getRank();
116 int64_t resRank = resultVectorType.getRank();
117 if (srcRank != 1 || resRank < 2)
122 for (int64_t dim = 0; dim < resRank - 1; ++dim)
123 numElts *= resultVectorType.getDimSize(dim);
127 auto loc = op.getLoc();
130 int64_t extractSize = resultVectorType.getShape().back();
131 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
132 for (int64_t i = 0; i < numElts; ++i) {
134 incIdx(srcIdx, sourceVectorType, extractSize);
135 incIdx(resIdx, resultVectorType, 1);
138 Value extract = rewriter.
create<vector::ExtractStridedSliceOp>(
139 loc, op.getSource(), srcIdx, extractSize,
141 result = rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
153 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
157 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
160 auto sourceVectorType = op.getSourceVectorType();
161 auto resultVectorType = op.getResultVectorType();
163 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
167 int64_t srcRank = sourceVectorType.getRank();
168 int64_t resRank = resultVectorType.getRank();
169 if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
175 for (int64_t r = 0; r < srcRank; r++)
176 numElts *= sourceVectorType.getDimSize(r);
185 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
186 for (int64_t i = 0; i < numElts; i++) {
188 incIdx(srcIdx, sourceVectorType);
189 incIdx(resIdx, resultVectorType);
195 assert(srcIdx.empty() &&
"Unexpected indices for 0-D vector");
196 extract = rewriter.
create<vector::ExtractElementOp>(
197 loc, op.getSourceVectorType().getElementType(), op.getSource());
200 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
205 assert(resIdx.empty() &&
"Unexpected indices for 0-D vector");
206 result = rewriter.
create<vector::InsertElementOp>(loc, extract, result);
209 rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
250 class ScalableShapeCastOpRewritePattern
255 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
259 auto sourceVectorType = op.getSourceVectorType();
260 auto resultVectorType = op.getResultVectorType();
261 auto srcRank = sourceVectorType.getRank();
262 auto resRank = resultVectorType.getRank();
269 if (!isTrailingDimScalable(sourceVectorType) ||
270 !isTrailingDimScalable(resultVectorType)) {
277 auto minSourceTrailingSize = sourceVectorType.getShape().back();
278 auto minResultTrailingSize = resultVectorType.getShape().back();
279 auto minExtractionSize =
280 std::min(minSourceTrailingSize, minResultTrailingSize);
281 int64_t minNumElts = 1;
282 for (
auto size : sourceVectorType.getShape())
289 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
291 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
297 Value currentResultScalableVector;
298 Value currentSourceScalableVector;
299 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
301 if (!currentSourceScalableVector) {
303 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
306 currentSourceScalableVector = op.getSource();
309 Value sourceSubVector = currentSourceScalableVector;
310 if (minExtractionSize < minSourceTrailingSize) {
311 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
312 loc, extractionVectorType, sourceSubVector, srcIdx.back());
316 if (!currentResultScalableVector) {
317 if (minExtractionSize == minResultTrailingSize) {
318 currentResultScalableVector = sourceSubVector;
319 }
else if (resRank != 1) {
320 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
323 currentResultScalableVector = result;
326 if (minExtractionSize < minResultTrailingSize) {
327 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
328 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
332 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
333 currentResultScalableVector != result) {
336 result = rewriter.
create<vector::InsertOp>(
337 loc, currentResultScalableVector, result,
339 currentResultScalableVector = {};
341 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
343 currentSourceScalableVector = {};
348 incIdx(srcIdx, sourceVectorType, minExtractionSize);
349 incIdx(resIdx, resultVectorType, minExtractionSize);
356 static bool isTrailingDimScalable(VectorType type) {
357 return type.getRank() >= 1 && type.getScalableDims().back() &&
358 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
366 patterns.add<ShapeCastOpNDDownCastRewritePattern,
367 ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
368 ScalableShapeCastOpRewritePattern>(
patterns.getContext(),
static void incIdx(SmallVectorImpl< int64_t > &indices, VectorType vecType, int step=1)
Increments n-D indices by step starting from the innermost dimension.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.