#define DEBUG_TYPE "vector-drop-unit-dim"
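// Returns a SmallVector of size `rank` containing all zeros, used below as an
// all-zero position for vector.extract. The original helper body is elided in
// this listing; this is a minimal implementation matching its documented
// behavior.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}

// Trims leading dimensions of size one from `oldType`, keeping at least one
// dimension. Scalable dimensions are never dropped.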
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;
  ArrayRef<bool> newScalableDims = oldType.getScalableDims();
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }
  // Make sure we keep at least one dimension.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
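// Casts away leading one dimensions in vector.extract_strided_slice's source
// and result vectors by extracting the trimmed source and rewriting the op on
// the lower-rank types.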
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // extract_strided_slice requires the source and result to have the same
    // rank, so drop the same number of leading unit dims from both.
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();
    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, extractOp.getVector(), splatZero(dropCount));

    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                     newExtractOp);
    return success();
  }
};
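// Casts away leading one dimensions in vector.insert_strided_slice's source
// and destination vectors and rewrites the op on the trimmed types.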
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();
    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);
    return success();
  }
};
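// Casts away leading one dimensions in vector.insert's source and destination
// vectors, adjusting the insertion position accordingly.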
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getValueToStoreType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();
    Value newSrcVector = insertOp.getValueToStore();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // Drop the indices that correspond to the dropped destination dims and
    // pad with zeros so the position fits the new source/destination ranks.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);
    return success();
  }
};
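// Drops leading unit dimensions from a transfer op mask so that it matches a
// trimmed vector type and permutation map.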
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the new mask type from the trimmed vector type and permutation map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
  int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
  // ... (elided: when the leading unit dims can be removed directly, `dropDim`
  // is used to extract the trimmed mask instead of shape casting)
  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}
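// Casts away leading one dimensions in vector.transfer_read's vector type,
// trimming the permutation map, in_bounds attribute, and mask to match.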
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // Bail on reads wrapped in vector.mask and on 0-d transfers.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getSource().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask;
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.getSource(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
    return success();
  }
};
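// Casts away leading one dimensions in vector.transfer_write's vector type,
// trimming the permutation map, in_bounds attribute, and mask to match.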
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // Bail on writes wrapped in vector.mask and on 0-d transfers.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    if (write.getTransferRank() == 0)
      return failure();

    auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = rewriter.create<vector::ExtractOp>(
        write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getSource(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getSource(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};
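// Casts away the leading unit dim, if present, of a vector.contract's
// accumulator (and correspondingly of its operands), transposing operands when
// needed so the dim can be dropped with a vector.extract. Returns the result
// of the rewritten (and possibly re-masked) contraction.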
FailureOr<Value> mlir::vector::castAwayContractionLeadingOneDim(
    vector::ContractionOp contractOp, MaskingOpInterface maskingOp,
    RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Only one dim is dropped per application; the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  // Only parallel iterators can be dropped.
  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check whether the dim to be dropped appears as the leading dim of this
    // operand; if so, it can be removed with a vector.extract.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // Either the operand needs a transpose to make the dim to drop leading,
      // or the dim does not appear in this operand at all.
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;
      bool transposeNeeded = false;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Check whether only outer unit dimensions are permuted. Such transposes
      // do not materially change the underlying vector and can be omitted.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] : llvm::enumerate(perm)) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
        }
      }

      // Transpose now, if needed, so the dim to drop can be extracted below.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }

    // If the dim to drop is now leading it can be extracted; otherwise it does
    // not exist in this operand and no extract is needed.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    // Build the new indexing map with the dropped dim removed and the
    // remaining dims shifted down by one.
    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(
                           loc, operands[it.index()], splatZero(dropDim))
                     : operands[it.index()]);
  }

  // Create the new contraction on the trimmed operands.
  Operation *newOp = rewriter.create<vector::ContractionOp>(
      loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  // If the original contraction was masked, trim the mask's leading unit dim
  // as well and re-apply it to the new contraction.
  if (maskingOp) {
    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
                                                      splatZero(dropDim));
    newOp = maskOperation(rewriter, newOp, newMask);
  }

  // Broadcast the result back to the original (leading-unit-dim) type.
  return rewriter
      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
                                   newOp->getResults()[0])
      .getResult();
}
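// Pattern wrapper that applies castAwayContractionLeadingOneDim to both plain
// and vector.mask-wrapped contractions.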
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};
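// Casts away leading one dimensions from the vector operands and result of
// elementwise ops, recreating the op on the trimmed type.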
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};
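// Casts away leading one dimensions from vector.constant_mask. The dropped
// leading mask sizes are folded into the first remaining dimension: if any of
// them is zero the whole mask becomes a zero mask, otherwise they have no
// effect.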
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

    // Fold the dropped leading mask sizes into the first remaining dimension.
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = rewriter.create<vector::ConstantMaskOp>(
        mask.getLoc(), newType, newDimSizes);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};
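// Collects all of the leading-one-dimension removal patterns defined above.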
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
}