17 #define DEBUG_TYPE "vector-drop-unit-dim"
31 while (!newShape.empty() && newShape.front() == 1 &&
32 !newScalableDims.front()) {
33 newShape = newShape.drop_front(1);
34 newScalableDims = newScalableDims.drop_front(1);
38 if (newShape.empty()) {
39 newShape = oldShape.take_back();
40 newScalableDims = oldType.getScalableDims().take_back();
42 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
53 struct CastAwayExtractStridedSliceLeadingOneDim
57 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
62 VectorType oldSrcType = extractOp.getSourceVectorType();
65 if (newSrcType.getRank() == oldSrcType.getRank())
68 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
70 VectorType oldDstType = extractOp.getType();
71 VectorType newDstType =
73 oldDstType.getElementType());
77 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
78 loc, extractOp.getVector(),
splatZero(dropCount));
83 extractOp.getOffsets().getValue().drop_front(dropCount));
85 extractOp.getSizes().getValue().drop_front(dropCount));
87 extractOp.getStrides().getValue().drop_front(dropCount));
89 auto newExtractOp = rewriter.
create<vector::ExtractStridedSliceOp>(
90 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
101 struct CastAwayInsertStridedSliceLeadingOneDim
105 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
107 VectorType oldSrcType = insertOp.getSourceVectorType();
109 VectorType oldDstType = insertOp.getDestVectorType();
112 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
113 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
114 if (srcDropCount == 0 && dstDropCount == 0)
120 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
121 loc, insertOp.getSource(),
splatZero(srcDropCount));
122 Value newDstVector = rewriter.
create<vector::ExtractOp>(
123 loc, insertOp.getDest(),
splatZero(dstDropCount));
126 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
128 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
130 auto newInsertOp = rewriter.
create<vector::InsertStridedSliceOp>(
131 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
142 struct CastAwayInsertLeadingOneDim :
public OpRewritePattern<vector::InsertOp> {
147 Type oldSrcType = insertOp.getSourceType();
148 Type newSrcType = oldSrcType;
149 int64_t oldSrcRank = 0, newSrcRank = 0;
150 if (
auto type = dyn_cast<VectorType>(oldSrcType)) {
152 oldSrcRank = type.getRank();
153 newSrcRank = cast<VectorType>(newSrcType).getRank();
156 VectorType oldDstType = insertOp.getDestVectorType();
159 int64_t srcDropCount = oldSrcRank - newSrcRank;
160 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
161 if (srcDropCount == 0 && dstDropCount == 0)
167 Value newSrcVector = insertOp.getSource();
168 if (oldSrcRank != 0) {
169 newSrcVector = rewriter.
create<vector::ExtractOp>(
170 loc, insertOp.getSource(),
splatZero(srcDropCount));
172 Value newDstVector = rewriter.
create<vector::ExtractOp>(
173 loc, insertOp.getDest(),
splatZero(dstDropCount));
179 unsigned oldPosRank = insertOp.getPosition().size();
180 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
182 llvm::to_vector(insertOp.getPosition().take_back(newPosRank));
183 newPositions.resize(newDstType.getRank() - newSrcRank, 0);
185 auto newInsertOp = rewriter.
create<vector::InsertOp>(
186 loc, newDstType, newSrcVector, newDstVector, newPositions);
198 struct CastAwayTransferReadLeadingOneDim
205 if (read.getTransferRank() == 0)
211 auto shapedType = cast<ShapedType>(read.getSource().getType());
212 if (shapedType.getElementType() != read.getVectorType().getElementType())
215 VectorType oldType = read.getVectorType();
218 if (newType == oldType)
223 oldMap.
getResults().take_back(newType.getRank());
228 ArrayAttr inBoundsAttr;
229 if (read.getInBounds())
231 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
233 auto newRead = rewriter.
create<vector::TransferReadOp>(
234 read.getLoc(), newType, read.getSource(), read.getIndices(),
246 struct CastAwayTransferWriteLeadingOneDim
253 if (write.getTransferRank() == 0)
259 auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
260 if (shapedType.getElementType() != write.getVectorType().getElementType())
263 VectorType oldType = write.getVectorType();
265 if (newType == oldType)
267 int64_t dropDim = oldType.getRank() - newType.getRank();
271 oldMap.
getResults().take_back(newType.getRank());
276 ArrayAttr inBoundsAttr;
277 if (write.getInBounds())
279 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
281 auto newVector = rewriter.
create<vector::ExtractOp>(
282 write.getLoc(), write.getVector(),
splatZero(dropDim));
284 write, newVector, write.getSource(), write.getIndices(),
296 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
297 if (oldAccType ==
nullptr)
299 if (oldAccType.getRank() < 2)
301 if (oldAccType.getShape()[0] != 1)
307 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
310 auto oldIteratorTypes = contractOp.getIteratorTypes();
313 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
320 int64_t currDim = it.index();
321 if (currDim == dimToDrop)
323 newIteratorTypes.push_back(it.value());
327 contractOp.getAcc()};
333 bool validExtract =
false;
335 auto map = it.value();
336 int64_t orginalZeroDim = it.value().getDimPosition(0);
337 if (orginalZeroDim != dimToDrop) {
343 bool tranposeNeeded =
false;
347 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
348 int64_t currDim = map.getDimPosition(i);
349 if (currDim == dimToDrop) {
350 tranposeNeeded =
true;
351 perm.insert(perm.begin(), i);
353 transposeResults.insert(transposeResults.begin(), targetExpr);
357 transposeResults.push_back(targetExpr);
362 if (tranposeNeeded) {
364 contractOp.getContext());
365 operands[it.index()] = rewriter.
create<vector::TransposeOp>(
366 contractOp.getLoc(), operands[it.index()], perm);
373 if (map.getDimPosition(0) == dimToDrop)
376 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
377 int64_t currDim = map.getDimPosition(i);
378 if (currDim == dimToDrop)
382 currDim < dimToDrop ? currDim : currDim - 1);
383 results.push_back(targetExpr);
385 newIndexingMaps.push_back(
AffineMap::get(map.getNumDims() - 1, 0, results,
386 contractOp.getContext()));
389 newOperands.push_back(
390 validExtract ? rewriter.
create<vector::ExtractOp>(contractOp.getLoc(),
391 operands[it.index()],
393 : operands[it.index()]);
395 auto newContractOp = rewriter.
create<vector::ContractionOp>(
396 contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
398 rewriter.
getArrayAttr(newIteratorTypes), contractOp.getKind());
400 contractOp, contractOp->getResultTypes()[0], newContractOp);
410 struct CastAwayContractionLeadingOneDim
414 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
434 CastAwayElementwiseLeadingOneDim(
MLIRContext *context,
446 if (newVecType == vecType)
448 int64_t dropDim = vecType.getRank() - newVecType.getRank();
451 if (
auto opVecType = dyn_cast<VectorType>(operand.getType())) {
452 newOperands.push_back(rewriter.
create<vector::ExtractOp>(
455 newOperands.push_back(operand);
460 newOperands, newVecType, op->
getAttrs());
472 .
add<CastAwayExtractStridedSliceLeadingOneDim,
473 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
474 CastAwayTransferReadLeadingOneDim,
475 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
476 CastAwayContractionLeadingOneDim>(patterns.
getContext(), benefit);
static SmallVector< int64_t > splatZero(int64_t rank)
Returns a SmallVector of size `rank` containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, RewriterBase &rewriter)
Cast away the leading unit dim, if exists, for the given contract op.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...