36 #define DEBUG_TYPE "vector-shape-cast-lowering"
47 class ShapeCastOp2DDownCastRewritePattern
52 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
54 auto sourceVectorType = op.getSourceVectorType();
55 auto resultVectorType = op.getResultVectorType();
57 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
60 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
63 auto loc = op.getLoc();
65 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
66 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
67 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
68 Value vec = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), i);
69 desc = rewriter.
create<vector::InsertStridedSliceOp>(
71 i * mostMinorVectorSize, 1);
84 class ShapeCastOp2DUpCastRewritePattern
89 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
91 auto sourceVectorType = op.getSourceVectorType();
92 auto resultVectorType = op.getResultVectorType();
94 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
97 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
100 auto loc = op.getLoc();
102 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
103 unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
104 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
105 Value vec = rewriter.
create<vector::ExtractStridedSliceOp>(
106 loc, op.getSource(), i * mostMinorVectorSize,
109 desc = rewriter.
create<vector::InsertOp>(loc, vec, desc, i);
117 int dimIdx,
int initialStep = 1) {
118 int step = initialStep;
119 for (
int d = dimIdx; d >= 0; d--) {
121 if (idx[d] >= tp.getDimSize(d)) {
135 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
139 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
142 auto sourceVectorType = op.getSourceVectorType();
143 auto resultVectorType = op.getResultVectorType();
145 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
150 int64_t srcRank = sourceVectorType.getRank();
151 int64_t resRank = resultVectorType.getRank();
152 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
161 for (int64_t r = 0; r < srcRank; r++)
162 numElts *= sourceVectorType.getDimSize(r);
172 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
173 for (int64_t i = 0; i < numElts; i++) {
175 incIdx(srcIdx, sourceVectorType, srcRank - 1);
176 incIdx(resIdx, resultVectorType, resRank - 1);
182 assert(srcIdx.empty() &&
"Unexpected indices for 0-D vector");
183 extract = rewriter.
create<vector::ExtractElementOp>(
184 loc, op.getSourceVectorType().getElementType(), op.getSource());
187 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
192 assert(resIdx.empty() &&
"Unexpected indices for 0-D vector");
193 result = rewriter.
create<vector::InsertElementOp>(loc, extract, result);
196 rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
237 class ScalableShapeCastOpRewritePattern
242 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
246 auto sourceVectorType = op.getSourceVectorType();
247 auto resultVectorType = op.getResultVectorType();
248 auto srcRank = sourceVectorType.getRank();
249 auto resRank = resultVectorType.getRank();
256 if (!isTrailingDimScalable(sourceVectorType) ||
257 !isTrailingDimScalable(resultVectorType)) {
264 auto minSourceTrailingSize = sourceVectorType.getShape().back();
265 auto minResultTrailingSize = resultVectorType.getShape().back();
266 auto minExtractionSize =
267 std::min(minSourceTrailingSize, minResultTrailingSize);
268 int64_t minNumElts = 1;
269 for (
auto size : sourceVectorType.getShape())
276 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
279 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
286 Value currentResultScalableVector;
287 Value currentSourceScalableVector;
288 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
290 if (!currentSourceScalableVector) {
292 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
295 currentSourceScalableVector = op.getSource();
298 Value sourceSubVector = currentSourceScalableVector;
299 if (minExtractionSize < minSourceTrailingSize) {
300 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
301 loc, extractionVectorType, sourceSubVector, srcIdx.back());
305 if (!currentResultScalableVector) {
306 if (minExtractionSize == minResultTrailingSize) {
307 currentResultScalableVector = sourceSubVector;
308 }
else if (resRank != 1) {
309 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
312 currentResultScalableVector = result;
315 if (minExtractionSize < minResultTrailingSize) {
316 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
317 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
321 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
322 currentResultScalableVector != result) {
325 result = rewriter.
create<vector::InsertOp>(
326 loc, currentResultScalableVector, result,
328 currentResultScalableVector = {};
330 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
332 currentSourceScalableVector = {};
337 incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
338 incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
345 static bool isTrailingDimScalable(VectorType type) {
346 return type.getRank() >= 1 && type.getScalableDims().back() &&
347 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
355 patterns.
add<ShapeCastOp2DDownCastRewritePattern,
356 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
357 ScalableShapeCastOpRewritePattern>(patterns.
getContext(),
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.