20 #define DEBUG_TYPE "vector-drop-unit-dim"
34 while (!newShape.empty() && newShape.front() == 1 &&
35 !newScalableDims.front()) {
36 newShape = newShape.drop_front(1);
37 newScalableDims = newScalableDims.drop_front(1);
41 if (newShape.empty()) {
42 newShape = oldShape.take_back();
43 newScalableDims = oldType.getScalableDims().take_back();
45 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
56 struct CastAwayExtractStridedSliceLeadingOneDim
60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
65 VectorType oldSrcType = extractOp.getSourceVectorType();
68 if (newSrcType.getRank() == oldSrcType.getRank())
71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
73 VectorType oldDstType = extractOp.getType();
74 VectorType newDstType =
76 oldDstType.getElementType(),
77 oldDstType.getScalableDims().drop_front(dropCount));
81 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
82 loc, extractOp.getVector(),
splatZero(dropCount));
87 extractOp.getOffsets().getValue().drop_front(dropCount));
89 extractOp.getSizes().getValue().drop_front(dropCount));
91 extractOp.getStrides().getValue().drop_front(dropCount));
93 auto newExtractOp = rewriter.
create<vector::ExtractStridedSliceOp>(
94 loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
105 struct CastAwayInsertStridedSliceLeadingOneDim
109 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
111 VectorType oldSrcType = insertOp.getSourceVectorType();
113 VectorType oldDstType = insertOp.getDestVectorType();
116 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
117 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
118 if (srcDropCount == 0 && dstDropCount == 0)
124 Value newSrcVector = rewriter.
create<vector::ExtractOp>(
125 loc, insertOp.getSource(),
splatZero(srcDropCount));
126 Value newDstVector = rewriter.
create<vector::ExtractOp>(
127 loc, insertOp.getDest(),
splatZero(dstDropCount));
130 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
132 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
134 auto newInsertOp = rewriter.
create<vector::InsertStridedSliceOp>(
135 loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
146 struct CastAwayInsertLeadingOneDim :
public OpRewritePattern<vector::InsertOp> {
149 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
151 Type oldSrcType = insertOp.getSourceType();
152 Type newSrcType = oldSrcType;
153 int64_t oldSrcRank = 0, newSrcRank = 0;
154 if (
auto type = dyn_cast<VectorType>(oldSrcType)) {
156 oldSrcRank = type.getRank();
157 newSrcRank = cast<VectorType>(newSrcType).getRank();
160 VectorType oldDstType = insertOp.getDestVectorType();
163 int64_t srcDropCount = oldSrcRank - newSrcRank;
164 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
165 if (srcDropCount == 0 && dstDropCount == 0)
171 Value newSrcVector = insertOp.getSource();
172 if (oldSrcRank != 0) {
173 newSrcVector = rewriter.
create<vector::ExtractOp>(
174 loc, insertOp.getSource(),
splatZero(srcDropCount));
176 Value newDstVector = rewriter.
create<vector::ExtractOp>(
177 loc, insertOp.getDest(),
splatZero(dstDropCount));
183 unsigned oldPosRank = insertOp.getNumIndices();
184 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
187 llvm::to_vector(
ArrayRef(oldPosition).take_back(newPosRank));
188 newPosition.resize(newDstType.getRank() - newSrcRank,
191 auto newInsertOp = rewriter.
create<vector::InsertOp>(
192 loc, newSrcVector, newDstVector, newPosition);
203 VectorType oldMaskType) {
212 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
215 return b.
create<vector::ShapeCastOp>(loc, newMaskType, mask);
221 struct CastAwayTransferReadLeadingOneDim
225 LogicalResult matchAndRewrite(vector::TransferReadOp read,
228 if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
231 if (read.getTransferRank() == 0)
234 auto shapedType = cast<ShapedType>(read.getSource().getType());
235 if (shapedType.getElementType() != read.getVectorType().getElementType())
238 VectorType oldType = read.getVectorType();
241 if (newType == oldType)
246 oldMap.
getResults().take_back(newType.getRank());
251 ArrayAttr inBoundsAttr;
252 if (read.getInBounds())
254 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
257 if (read.getMask()) {
258 VectorType maskType = read.getMaskType();
259 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
260 newType, newMap, maskType);
263 auto newRead = rewriter.
create<vector::TransferReadOp>(
264 read.getLoc(), newType, read.getSource(), read.getIndices(),
275 struct CastAwayTransferWriteLeadingOneDim
279 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
282 if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
285 if (write.getTransferRank() == 0)
288 auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
289 if (shapedType.getElementType() != write.getVectorType().getElementType())
292 VectorType oldType = write.getVectorType();
294 if (newType == oldType)
296 int64_t dropDim = oldType.getRank() - newType.getRank();
300 oldMap.
getResults().take_back(newType.getRank());
305 ArrayAttr inBoundsAttr;
306 if (write.getInBounds())
308 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
310 auto newVector = rewriter.
create<vector::ExtractOp>(
311 write.getLoc(), write.getVector(),
splatZero(dropDim));
313 if (write.getMask()) {
314 VectorType maskType = write.getMaskType();
315 Value newMask = dropUnitDimsFromMask(
316 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
318 write, newVector, write.getSource(), write.getIndices(),
324 write, newVector, write.getSource(), write.getIndices(),
334 MaskingOpInterface maskingOp,
336 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
337 if (oldAccType ==
nullptr)
339 if (oldAccType.getRank() < 2)
341 if (oldAccType.getShape()[0] != 1)
347 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
350 auto oldIteratorTypes = contractOp.getIteratorTypes();
353 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
360 int64_t currDim = it.index();
361 if (currDim == dimToDrop)
363 newIteratorTypes.push_back(it.value());
367 contractOp.getAcc()};
369 auto loc = contractOp.getLoc();
374 bool validExtract =
false;
376 auto map = it.value();
377 int64_t orginalZeroDim = it.value().getDimPosition(0);
378 if (orginalZeroDim != dimToDrop) {
384 bool tranposeNeeded =
false;
388 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
389 int64_t currDim = map.getDimPosition(i);
390 if (currDim == dimToDrop) {
391 tranposeNeeded =
true;
392 perm.insert(perm.begin(), i);
394 transposeResults.insert(transposeResults.begin(), targetExpr);
398 transposeResults.push_back(targetExpr);
405 bool transposeNonOuterUnitDims =
false;
406 auto operandShape = cast<ShapedType>(operands[it.index()].getType());
407 for (
auto [index, dim] :
409 if (dim !=
static_cast<int64_t
>(index) &&
410 operandShape.getDimSize(index) != 1) {
411 transposeNonOuterUnitDims =
true;
418 if (tranposeNeeded) {
420 contractOp.getContext());
421 if (transposeNonOuterUnitDims) {
422 operands[it.index()] = rewriter.
createOrFold<vector::TransposeOp>(
423 loc, operands[it.index()], perm);
431 if (map.getDimPosition(0) == dimToDrop)
434 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
435 int64_t currDim = map.getDimPosition(i);
436 if (currDim == dimToDrop)
440 currDim < dimToDrop ? currDim : currDim - 1);
441 results.push_back(targetExpr);
443 newIndexingMaps.push_back(
AffineMap::get(map.getNumDims() - 1, 0, results,
444 contractOp.getContext()));
447 newOperands.push_back(
448 validExtract ? rewriter.
create<vector::ExtractOp>(
449 loc, operands[it.index()],
splatZero(dropDim))
450 : operands[it.index()]);
456 loc, newOperands[0], newOperands[1], newOperands[2],
458 rewriter.
getArrayAttr(newIteratorTypes), contractOp.getKind());
461 auto newMask = rewriter.
create<vector::ExtractOp>(loc, maskingOp.getMask(),
468 .
create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
479 struct CastAwayContractionLeadingOneDim
481 using MaskableOpRewritePattern::MaskableOpRewritePattern;
484 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
485 MaskingOpInterface maskingOp,
505 CastAwayElementwiseLeadingOneDim(
MLIRContext *context,
509 LogicalResult matchAndRewrite(
Operation *op,
517 if (newVecType == vecType)
519 int64_t dropDim = vecType.getRank() - newVecType.getRank();
522 if (
auto opVecType = dyn_cast<VectorType>(operand.getType())) {
523 newOperands.push_back(rewriter.
create<vector::ExtractOp>(
526 newOperands.push_back(operand);
531 newOperands, newVecType, op->
getAttrs());
540 struct CastAwayConstantMaskLeadingOneDim
544 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
546 VectorType oldType = mask.getType();
549 if (newType == oldType)
552 int64_t dropDim = oldType.getRank() - newType.getRank();
557 int64_t flatLeadingSize =
558 std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
559 static_cast<int64_t
>(1), std::multiplies<int64_t>());
561 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
563 auto newMask = rewriter.
create<vector::ConstantMaskOp>(
564 mask.getLoc(), newType, newDimSizes);
575 .add<CastAwayExtractStridedSliceLeadingOneDim,
576 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
577 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
578 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
579 CastAwayContractionLeadingOneDim>(
patterns.getContext(), benefit);
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< Value > castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter)
Cast away the leading unit dim, if exists, for the given contract op.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.