37 #define DEBUG_TYPE "vector-shape-cast-lowering"
48 class ShapeCastOp2DDownCastRewritePattern
55 auto sourceVectorType = op.getSourceVectorType();
56 auto resultVectorType = op.getResultVectorType();
58 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
61 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
66 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
67 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
68 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
69 Value vec = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), i);
70 desc = rewriter.
create<vector::InsertStridedSliceOp>(
72 i * mostMinorVectorSize, 1);
85 class ShapeCastOp2DUpCastRewritePattern
92 auto sourceVectorType = op.getSourceVectorType();
93 auto resultVectorType = op.getResultVectorType();
95 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
98 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
103 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
104 unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
105 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
106 Value vec = rewriter.
create<vector::ExtractStridedSliceOp>(
107 loc, op.getSource(), i * mostMinorVectorSize,
110 desc = rewriter.
create<vector::InsertOp>(loc, vec, desc, i);
118 int dimIdx,
int initialStep = 1) {
119 int step = initialStep;
120 for (
int d = dimIdx; d >= 0; d--) {
122 if (idx[d] >= tp.getDimSize(d)) {
136 class ShapeCastOpRewritePattern :
public OpRewritePattern<vector::ShapeCastOp> {
143 auto sourceVectorType = op.getSourceVectorType();
144 auto resultVectorType = op.getResultVectorType();
146 if (sourceVectorType.isScalable() || resultVectorType.isScalable())
151 int64_t srcRank = sourceVectorType.getRank();
152 int64_t resRank = resultVectorType.getRank();
153 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
162 for (int64_t r = 0; r < srcRank; r++)
163 numElts *= sourceVectorType.getDimSize(r);
173 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
174 for (int64_t i = 0; i < numElts; i++) {
176 incIdx(srcIdx, sourceVectorType, srcRank - 1);
177 incIdx(resIdx, resultVectorType, resRank - 1);
183 assert(srcIdx.empty() &&
"Unexpected indices for 0-D vector");
184 extract = rewriter.
create<vector::ExtractElementOp>(
185 loc, op.getSourceVectorType().getElementType(), op.getSource());
188 rewriter.
create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
193 assert(resIdx.empty() &&
"Unexpected indices for 0-D vector");
194 result = rewriter.
create<vector::InsertElementOp>(loc, extract, result);
197 rewriter.
create<vector::InsertOp>(loc, extract, result, resIdx);
238 class ScalableShapeCastOpRewritePattern
247 auto sourceVectorType = op.getSourceVectorType();
248 auto resultVectorType = op.getResultVectorType();
249 auto srcRank = sourceVectorType.getRank();
250 auto resRank = resultVectorType.getRank();
257 if (!isTrailingDimScalable(sourceVectorType) ||
258 !isTrailingDimScalable(resultVectorType)) {
265 auto minSourceTrailingSize = sourceVectorType.getShape().back();
266 auto minResultTrailingSize = resultVectorType.getShape().back();
267 auto minExtractionSize =
268 std::min(minSourceTrailingSize, minResultTrailingSize);
269 int64_t minNumElts = 1;
270 for (
auto size : sourceVectorType.getShape())
277 {minExtractionSize}, sourceVectorType.getElementType(), {
true});
280 loc, resultVectorType, rewriter.
getZeroAttr(resultVectorType));
287 Value currentResultScalableVector;
288 Value currentSourceScalableVector;
289 for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
291 if (!currentSourceScalableVector) {
293 currentSourceScalableVector = rewriter.
create<vector::ExtractOp>(
296 currentSourceScalableVector = op.getSource();
299 Value sourceSubVector = currentSourceScalableVector;
300 if (minExtractionSize < minSourceTrailingSize) {
301 sourceSubVector = rewriter.
create<vector::ScalableExtractOp>(
302 loc, extractionVectorType, sourceSubVector, srcIdx.back());
306 if (!currentResultScalableVector) {
307 if (minExtractionSize == minResultTrailingSize) {
308 currentResultScalableVector = sourceSubVector;
309 }
else if (resRank != 1) {
310 currentResultScalableVector = rewriter.
create<vector::ExtractOp>(
313 currentResultScalableVector = result;
316 if (minExtractionSize < minResultTrailingSize) {
317 currentResultScalableVector = rewriter.
create<vector::ScalableInsertOp>(
318 loc, sourceSubVector, currentResultScalableVector, resIdx.back());
322 if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
323 currentResultScalableVector != result) {
326 result = rewriter.
create<vector::InsertOp>(
327 loc, currentResultScalableVector, result,
329 currentResultScalableVector = {};
331 if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
333 currentSourceScalableVector = {};
338 incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
339 incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
346 static bool isTrailingDimScalable(VectorType type) {
347 return type.getRank() >= 1 && type.getScalableDims().back() &&
348 !llvm::is_contained(type.getScalableDims().drop_back(),
true);
356 patterns.
add<ShapeCastOp2DDownCastRewritePattern,
357 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
358 ScalableShapeCastOpRewritePattern>(patterns.
getContext(),
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...