17 #define DEBUG_TYPE "vector-drop-unit-dim"
27 oldShape.drop_while([](int64_t dim) {
return dim == 1; });
30 newShape = oldShape.take_back();
31 return VectorType::get(newShape, oldType.getElementType());
42 struct CastAwayExtractStridedSliceLeadingOneDim
46 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
51 VectorType oldSrcType = extractOp.getSourceVectorType();
54 if (newSrcType.getRank() == oldSrcType.getRank())
57 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
59 VectorType oldDstType = extractOp.getType();
60 VectorType newDstType =
61 VectorType::get(oldDstType.getShape().drop_front(dropCount),
62 oldDstType.getElementType());
66 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
67 loc, extractOp.getVector(),
splatZero(dropCount));
72 extractOp.getOffsets().getValue().drop_front(dropCount));
74 extractOp.getSizes().getValue().drop_front(dropCount));
76 extractOp.getStrides().getValue().drop_front(dropCount));
78 auto newExtractOp = rewriter.
create<vector::ExtractStridedSliceOp>(
79 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
90 struct CastAwayInsertStridedSliceLeadingOneDim
94 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
96 VectorType oldSrcType = insertOp.getSourceVectorType();
98 VectorType oldDstType = insertOp.getDestVectorType();
101 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
102 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
103 if (srcDropCount == 0 && dstDropCount == 0)
109 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
110 loc, insertOp.getSource(),
splatZero(srcDropCount));
111 Value newDstVector = rewriter.
create<vector::ExtractOp>(
112 loc, insertOp.getDest(),
splatZero(dstDropCount));
115 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
117 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
119 auto newInsertOp = rewriter.
create<vector::InsertStridedSliceOp>(
120 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
131 struct CastAwayInsertLeadingOneDim :
public OpRewritePattern<vector::InsertOp> {
136 Type oldSrcType = insertOp.getSourceType();
137 Type newSrcType = oldSrcType;
138 int64_t oldSrcRank = 0, newSrcRank = 0;
139 if (
auto type = oldSrcType.
dyn_cast<VectorType>()) {
141 oldSrcRank = type.getRank();
142 newSrcRank = newSrcType.
cast<VectorType>().getRank();
145 VectorType oldDstType = insertOp.getDestVectorType();
148 int64_t srcDropCount = oldSrcRank - newSrcRank;
149 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
150 if (srcDropCount == 0 && dstDropCount == 0)
156 Value newSrcVector = insertOp.getSource();
157 if (oldSrcRank != 0) {
158 newSrcVector = rewriter.
create<vector::ExtractOp>(
159 loc, insertOp.getSource(),
splatZero(srcDropCount));
161 Value newDstVector = rewriter.
create<vector::ExtractOp>(
162 loc, insertOp.getDest(),
splatZero(dstDropCount));
164 unsigned oldPosRank = insertOp.getPosition().getValue().size();
165 unsigned newPosRank = newDstType.getRank() - newSrcRank;
167 insertOp.getPosition().getValue().take_back(newPosRank));
168 if (newPosRank > oldPosRank) {
170 newPositions.resize(newPosRank, zeroAttr);
173 auto newInsertOp = rewriter.
create<vector::InsertOp>(
174 loc, newDstType, newSrcVector, newDstVector,
187 struct CastAwayTransferReadLeadingOneDim
194 if (read.getTransferRank() == 0)
200 auto shapedType = read.getSource().getType().cast<ShapedType>();
201 if (shapedType.getElementType() != read.getVectorType().getElementType())
204 VectorType oldType = read.getVectorType();
207 if (newType == oldType)
212 oldMap.
getResults().take_back(newType.getRank());
217 ArrayAttr inBoundsAttr;
218 if (read.getInBounds())
220 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
222 auto newRead = rewriter.
create<vector::TransferReadOp>(
223 read.getLoc(), newType, read.getSource(), read.getIndices(),
224 AffineMapAttr::get(newMap), read.getPadding(),
Value(),
235 struct CastAwayTransferWriteLeadingOneDim
242 if (write.getTransferRank() == 0)
248 auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
249 if (shapedType.getElementType() != write.getVectorType().getElementType())
252 VectorType oldType = write.getVectorType();
254 if (newType == oldType)
256 int64_t dropDim = oldType.getRank() - newType.getRank();
260 oldMap.
getResults().take_back(newType.getRank());
265 ArrayAttr inBoundsAttr;
266 if (write.getInBounds())
268 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
270 auto newVector = rewriter.
create<vector::ExtractOp>(
271 write.getLoc(), write.getVector(),
splatZero(dropDim));
273 write, newVector, write.getSource(), write.getIndices(),
274 AffineMapAttr::get(newMap), inBoundsAttr);
284 struct CastAwayContractionLeadingOneDim
288 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
290 VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
291 if (oldAccType ==
nullptr)
293 if (oldAccType.getRank() < 2)
296 if (!contractOp.getMasks().empty())
298 if (oldAccType.getShape()[0] != 1)
304 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
307 auto oldIteratorTypes = contractOp.getIteratorTypes();
310 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
317 int64_t currDim = it.index();
318 if (currDim == dimToDrop)
320 newIteratorTypes.push_back(it.value());
324 contractOp.getAcc()};
330 bool validExtract =
false;
332 auto map = it.value();
333 int64_t orginalZeroDim = it.value().getDimPosition(0);
334 if (orginalZeroDim != dimToDrop) {
340 bool tranposeNeeded =
false;
344 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
345 int64_t currDim = map.getDimPosition(i);
346 if (currDim == dimToDrop) {
347 tranposeNeeded =
true;
348 perm.insert(perm.begin(), i);
350 transposeResults.insert(transposeResults.begin(), targetExpr);
354 transposeResults.push_back(targetExpr);
359 if (tranposeNeeded) {
361 contractOp.getContext());
362 operands[it.index()] = rewriter.
create<vector::TransposeOp>(
363 contractOp.getLoc(), operands[it.index()], perm);
370 if (map.getDimPosition(0) == dimToDrop)
373 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
374 int64_t currDim = map.getDimPosition(i);
375 if (currDim == dimToDrop)
379 currDim < dimToDrop ? currDim : currDim - 1);
380 results.push_back(targetExpr);
382 newIndexingMaps.push_back(
AffineMap::get(map.getNumDims() - 1, 0, results,
383 contractOp.getContext()));
386 newOperands.push_back(validExtract
387 ? rewriter.
create<vector::ExtractOp>(
388 contractOp.getLoc(), operands[it.index()],
390 : operands[it.index()]);
392 auto newContractOp = rewriter.
create<vector::ContractionOp>(
393 contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
395 rewriter.
getArrayAttr(newIteratorTypes), contractOp.getKind());
397 contractOp, contractOp->getResultTypes()[0], newContractOp);
404 CastAwayElementwiseLeadingOneDim(
MLIRContext *context,
416 if (newVecType == vecType)
418 int64_t dropDim = vecType.getRank() - newVecType.getRank();
421 if (
auto opVecType = operand.getType().dyn_cast<VectorType>()) {
422 newOperands.push_back(rewriter.
create<vector::ExtractOp>(
425 newOperands.push_back(operand);
430 newOperands, newVecType, op->
getAttrs());
442 .
add<CastAwayExtractStridedSliceLeadingOneDim,
443 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
444 CastAwayTransferReadLeadingOneDim,
445 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
446 CastAwayContractionLeadingOneDim>(patterns.
getContext(), benefit);
static SmallVector< int64_t > splatZero(int64_t rank)
Return a SmallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
Attribute getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool hasElementwiseMappableTraits(Operation *op)
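Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.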
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.