// Names this file's debug category as "vector-shape-cast-lowering".
// NOTE(review): presumably consumed by LLVM's debug-output macros; no use of
// the macro is visible in this extract.
25 #define DEBUG_TYPE "vector-shape-cast-lowering"
// Fragment of the index-increment loop (several listing lines between the
// ones below are absent from this extract). Index positions are visited in
// reverse order, from indices.size() - 1 down to 0.
33 for (
int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
// Precondition: on entry the index at this position must be strictly below
// the size of the matching dimension of vecType.
34 assert(indices[dim] < vecType.getDimSize(dim) &&
35 "Indices are out of bound");
// Re-tests the same bound; the statement(s) between the assert and this check
// and the statement this check guards are not shown here.
// NOTE(review): presumably the index is advanced in between - confirm against
// the full file.
37 if (indices[dim] < vecType.getDimSize(dim))
// Fragment of ShapeCastOpNDDownCastRewritePattern. Many listing lines are
// absent from this extract, so only the visible statements are described.
50 class ShapeCastOpNDDownCastRewritePattern
55 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
57 auto sourceVectorType = op.getSourceVectorType();
58 auto resultVectorType = op.getResultVectorType();
// Tests whether either vector type is scalable; the guarded statement is not
// shown (presumably the pattern bails out - confirm).
59 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
62 int64_t srcRank = sourceVectorType.getRank();
63 int64_t resRank = resultVectorType.getRank();
// Tests for a source of rank < 2 or a result whose rank is not 1; the guarded
// statement is not shown (presumably a match failure - confirm).
64 if (srcRank < 2 || resRank != 1)
// Multiplies numElts by every source dimension except the last one.
69 for (int64_t dim = 0; dim < srcRank - 1; ++dim)
70 numElts *= sourceVectorType.getDimSize(dim);
72 auto loc = op.getLoc();
// Size of the last source dimension; also used below as the step applied to
// the result index.
75 int64_t extractSize = sourceVectorType.getShape().back();
// The accumulated result starts out as a ub::PoisonOp of the result type.
76 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
// One iteration per counted element: the source index advances by 1 and the
// result index by extractSize.
// NOTE(review): any condition guarding the two incIdx calls is not visible.
80 for (int64_t i = 0; i < numElts; ++i) {
82 incIdx(srcIdx, sourceVectorType, 1);
83 incIdx(resIdx, resultVectorType, extractSize);
// Extracts from the source operand at position srcIdx.
87 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
// Updates result via an InsertStridedSliceOp; its arguments are not part of
// this extract.
88 result = rewriter.
create<vector::InsertStridedSliceOp>(
// Fragment of ShapeCastOpNDUpCastRewritePattern. Many listing lines are absent
// from this extract, so only the visible statements are described.
103 class ShapeCastOpNDUpCastRewritePattern
108 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
110 auto sourceVectorType = op.getSourceVectorType();
111 auto resultVectorType = op.getResultVectorType();
// Tests whether either vector type is scalable; the guarded statement is not
// shown (presumably the pattern bails out - confirm).
112 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
115 int64_t srcRank = sourceVectorType.getRank();
116 int64_t resRank = resultVectorType.getRank();
// Tests for a source whose rank is not 1 or a result of rank < 2; the guarded
// statement is not shown (presumably a match failure - confirm).
117 if (srcRank != 1 || resRank < 2)
// Multiplies numElts by every result dimension except the last one.
122 for (int64_t dim = 0; dim < resRank - 1; ++dim)
123 numElts *= resultVectorType.getDimSize(dim);
127 auto loc = op.getLoc();
// Size of the last result dimension; used as the source-index step and as an
// argument of the strided-slice extraction below.
130 int64_t extractSize = resultVectorType.getShape().back();
// The accumulated result starts out as a ub::PoisonOp of the result type.
131 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
// One iteration per counted element: the source index advances by extractSize
// and the result index by 1.
// NOTE(review): any condition guarding the two incIdx calls is not visible.
132 for (int64_t i = 0; i < numElts; ++i) {
134 incIdx(srcIdx, sourceVectorType, extractSize);
135 incIdx(resIdx, resultVectorType, 1);
// Builds an ExtractStridedSliceOp on the source operand from srcIdx and
// extractSize; the remaining arguments are not part of this extract.
138 Value extract = rewriter.
create<vector::ExtractStridedSliceOp>(
139 loc, op.getSource(), srcIdx, extractSize,
// Inserts the extracted value into result at position resIdx.
141 result = rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
// Fragment of ShapeCastOpRewritePattern, an OpRewritePattern over
// vector::ShapeCastOp. Many listing lines are absent from this extract, so
// only the visible statements are described.
153 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
157 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
160 auto sourceVectorType = op.getSourceVectorType();
161 auto resultVectorType = op.getResultVectorType();
// Tests whether either vector type is scalable; the guarded statement is not
// shown (presumably the pattern bails out - confirm).
163 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
167 int64_t srcRank = sourceVectorType.getRank();
168 int64_t resRank = resultVectorType.getRank();
// Tests for the two rank combinations that the N-D down-cast and up-cast
// patterns in this file check for (rank > 1 to rank 1, and rank 1 to
// rank > 1); the guarded statement is not shown.
169 if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
// Multiplies numElts by every source dimension.
175 for (int64_t r = 0; r < srcRank; r++)
176 numElts *= sourceVectorType.getDimSize(r);
// The accumulated result starts out as a ub::PoisonOp of the result type.
185 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
// One iteration per counted element; both indices advance by incIdx's default
// step.
// NOTE(review): any condition guarding the two incIdx calls is not visible.
186 for (int64_t i = 0; i < numElts; i++) {
188 incIdx(srcIdx, sourceVectorType);
189 incIdx(resIdx, resultVectorType);
// Extracts from the source operand at srcIdx, then inserts into result at
// resIdx.
193 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
194 result = rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
// Fragment of ScalableShapeCastOpRewritePattern. Many listing lines are absent
// from this extract, so only the visible statements are described.
234 class ScalableShapeCastOpRewritePattern
239 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
243 auto sourceVectorType = op.getSourceVectorType();
244 auto resultVectorType = op.getResultVectorType();
245 auto srcRank = sourceVectorType.getRank();
246 auto resRank = resultVectorType.getRank();
// Taken when either type fails isTrailingDimScalable; the body of this branch
// is not shown (presumably a match failure - confirm).
253 if (!isTrailingDimScalable(sourceVectorType) ||
254 !isTrailingDimScalable(resultVectorType)) {
// Sizes of the trailing dimension as stored in each type's shape; the
// extraction size is the smaller of the two.
261 auto minSourceTrailingSize = sourceVectorType.getShape().back();
262 auto minResultTrailingSize = resultVectorType.getShape().back();
263 auto minExtractionSize =
264 std::min(minSourceTrailingSize, minResultTrailingSize);
// Starts at 1 and is updated in a loop over the source shape; the loop body
// is not shown (presumably a running product - confirm).
265 int64_t minNumElts = 1;
266 for (
auto size : sourceVectorType.getShape())
// Trailing arguments of a call whose callee is not visible: a one-element
// shape {minExtractionSize}, the source element type, and a one-element
// {true} list.
// NOTE(review): presumably this builds extractionVectorType with a scalable
// dimension - confirm.
273 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
// The accumulated result starts out as a ub::PoisonOp of the result type.
275 Value result = rewriter.
create<ub::PoisonOp>(loc, resultVectorType);
// Vectors currently being written to / read from; a null Value means none is
// held at the moment (see the `!current...` tests below).
281 Value currentResultScalableVector;
282 Value currentSourceScalableVector;
// Steps through the element count in chunks of minExtractionSize.
283 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
// No source vector held: obtain one, either through an ExtractOp or by using
// the source operand directly (the condition choosing between the two is not
// shown).
285 if (!currentSourceScalableVector) {
287 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
290 currentSourceScalableVector = op.getSource();
// When the chunk is narrower than the source trailing dimension, take a
// sub-vector of type extractionVectorType at offset srcIdx.back().
293 Value sourceSubVector = currentSourceScalableVector;
294 if (minExtractionSize < minSourceTrailingSize) {
295 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
296 loc, extractionVectorType, sourceSubVector, srcIdx.back());
// No result vector held: if the chunk has exactly the result's trailing size,
// the source sub-vector is used as-is; otherwise, for a result rank other
// than 1, one is obtained through an ExtractOp; the final assignment uses
// `result` itself (its enclosing branch is not shown).
300 if (!currentResultScalableVector) {
301 if (minExtractionSize == minResultTrailingSize) {
302 currentResultScalableVector = sourceSubVector;
303 }
else if (resRank != 1) {
304 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
307 currentResultScalableVector = result;
// When the chunk is narrower than the result trailing dimension, insert it
// into the held result vector at offset resIdx.back().
310 if (minExtractionSize < minResultTrailingSize) {
311 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
312 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
// Once this chunk reaches the end of the result's trailing dimension, and the
// held vector is not `result` itself, write it back into `result` with an
// InsertOp and drop the held vector.
316 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
317 currentResultScalableVector != result) {
320 result = rewriter.
create<vector::InsertOp>(
321 loc, currentResultScalableVector, result,
323 currentResultScalableVector = {};
// Once this chunk reaches the end of the source's trailing dimension, drop
// the held source vector so that a new one is obtained on the next iteration.
325 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
327 currentSourceScalableVector = {};
// Advance both n-D indices by one chunk.
332 incIdx(srcIdx, sourceVectorType, minExtractionSize);
333 incIdx(resIdx, resultVectorType, minExtractionSize);
// Returns true when `type` has rank >= 1, its last dimension is flagged as
// scalable, and none of the preceding dimensions is flagged as scalable.
340 static bool isTrailingDimScalable(VectorType type) {
341 return type.getRank() >= 1 && type.getScalableDims().back() &&
342 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
// Adds the four shape-cast patterns of this file to `patterns`, passing the
// pattern set's context as the first constructor argument; the remaining
// arguments of the call are not part of this extract.
350 patterns.add<ShapeCastOpNDDownCastRewritePattern,
351 ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
352 ScalableShapeCastOpRewritePattern>(
patterns.getContext(),
static void incIdx(SmallVectorImpl< int64_t > &indices, VectorType vecType, int step=1)
Increments n-D indices by step starting from the innermost dimension.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.