37 #define DEBUG_TYPE "vector-shape-cast-lowering"
48 class ShapeCastOp2DDownCastRewritePattern
55 auto sourceVectorType = op.getSourceVectorType();
56 auto resultVectorType = op.getResultVectorType();
58 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
61 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
66 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
67 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
68 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
69 Value vec = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), i);
70 desc = rewriter.
create<vector::InsertStridedSliceOp>(
72 i * mostMinorVectorSize, 1);
85 class ShapeCastOp2DUpCastRewritePattern
92 auto sourceVectorType = op.getSourceVectorType();
93 auto resultVectorType = op.getResultVectorType();
95 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
98 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
103 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
104 unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
105 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
106 Value vec = rewriter.
create<vector::ExtractStridedSliceOp>(
107 loc, op.getSource(), i * mostMinorVectorSize,
110 desc = rewriter.
create<vector::InsertOp>(loc, vec, desc, i);
118 int dimIdx,
int initialStep = 1) {
119 int step = initialStep;
120 for (
int d = dimIdx; d >= 0; d--) {
122 if (idx[d] >= tp.getDimSize(d)) {
136 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
143 auto sourceVectorType = op.getSourceVectorType();
144 auto resultVectorType = op.getResultVectorType();
146 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
151 int64_t srcRank = sourceVectorType.getRank();
152 int64_t resRank = resultVectorType.getRank();
153 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
162 for (int64_t r = 0; r < srcRank; r++)
163 numElts *= sourceVectorType.getDimSize(r);
173 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
174 for (int64_t i = 0; i < numElts; i++) {
176 incIdx(srcIdx, sourceVectorType, srcRank - 1);
177 incIdx(resIdx, resultVectorType, resRank - 1);
183 assert(srcIdx.empty() &&
"Unexpected indices for 0-D vector");
184 extract = rewriter.
create<vector::ExtractElementOp>(
185 loc, op.getSourceVectorType().getElementType(), op.getSource());
188 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
193 assert(resIdx.empty() &&
"Unexpected indices for 0-D vector");
194 result = rewriter.
create<vector::InsertElementOp>(loc, extract, result);
197 rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
238 class ScalableShapeCastOpRewritePattern
247 auto sourceVectorType = op.getSourceVectorType();
248 auto resultVectorType = op.getResultVectorType();
249 auto srcRank = sourceVectorType.getRank();
250 auto resRank = resultVectorType.getRank();
257 if (!isTrailingDimScalable(sourceVectorType) ||
258 !isTrailingDimScalable(resultVectorType)) {
265 auto minSourceTrailingSize = sourceVectorType.getShape().back();
266 auto minResultTrailingSize = resultVectorType.getShape().back();
267 auto minExtractionSize =
268 std::min(minSourceTrailingSize, minResultTrailingSize);
269 int64_t minNumElts = 1;
270 for (
auto size : sourceVectorType.getShape())
277 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
280 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
287 Value currentResultScalableVector;
288 Value currentSourceScalableVector;
289 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
291 if (!currentSourceScalableVector) {
293 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
296 currentSourceScalableVector = op.getSource();
299 Value sourceSubVector = currentSourceScalableVector;
300 if (minExtractionSize < minSourceTrailingSize) {
301 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
302 loc, extractionVectorType, sourceSubVector, srcIdx.back());
306 if (!currentResultScalableVector) {
307 if (minExtractionSize == minResultTrailingSize) {
308 currentResultScalableVector = sourceSubVector;
309 }
else if (resRank != 1) {
310 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
313 currentResultScalableVector = result;
316 if (minExtractionSize < minResultTrailingSize) {
317 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
318 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
322 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
323 currentResultScalableVector != result) {
326 result = rewriter.
create<vector::InsertOp>(
327 loc, currentResultScalableVector, result,
329 currentResultScalableVector = {};
331 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
333 currentSourceScalableVector = {};
338 incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
339 incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
346 static bool isTrailingDimScalable(VectorType type) {
347 return type.getRank() >= 1 && type.getScalableDims().back() &&
348 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
356 patterns.
add<ShapeCastOp2DDownCastRewritePattern,
357 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
358 ScalableShapeCastOpRewritePattern>(patterns.
getContext(),
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...