#define DEBUG_TYPE "vector-drop-unit-dim"
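// Helper used throughout this file. Its signature and behaviour ("return a
// SmallVector of size `rank` containing all zeros") are taken from the
// helper's own documentation; the one-line body below is a sketch consistent
// with that description.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}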
// Trims leading unit dimensions from `oldType`, keeping at least one dim.
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;
  ArrayRef<bool> newScalableDims = oldType.getScalableDims();
  // Only drop leading dims that are statically 1 and not scalable.
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }
  // Make sure at least one dimension remains.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
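// Casts away leading unit dims of vector.extract_strided_slice's source: the
// source is first reduced with a vector.extract at an all-zero position, then
// the strided slice is rebuilt on the lower-rank vector.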
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();
    Value newSrcVector = vector::ExtractOp::create(
        rewriter, loc, extractOp.getVector(), splatZero(dropCount));

    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
        newStrides);
    // ... (elided: broadcast newExtractOp back to oldDstType and replace
    //      extractOp)
    return success();
  }
};
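// Casts away leading unit dims of both operands of
// vector.insert_strided_slice and rebuilds the op on the lower-rank vectors.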
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading unit dims from both the value to store and the destination.
    Location loc = insertOp.getLoc();
    Value newSrcVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    Value newDstVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = vector::InsertStridedSliceOp::create(
        rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
        newStrides);
    // ... (elided: broadcast newInsertOp back to oldDstType and replace
    //      insertOp)
    return success();
  }
};
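// Casts away leading unit dims of vector.insert's source and destination,
// shortening the insertion position accordingly. The source may be a scalar,
// in which case only the destination is trimmed.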
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getValueToStoreType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading unit dims from the value to store (if it is a vector) and
    // from the destination.
    Location loc = insertOp.getLoc();
    Value newSrcVector = insertOp.getValueToStore();
    if (oldSrcRank != 0) {
      newSrcVector = vector::ExtractOp::create(
          rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    }
    Value newDstVector = vector::ExtractOp::create(
        rewriter, loc, insertOp.getDest(), splatZero(dstDropCount));

    // Shorten the position to match the trimmed destination, then pad it with
    // zeros for the unit dims dropped from the source.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
                                                newDstVector, newPosition);
    // ... (elided: broadcast newInsertOp back to oldDstType and replace
    //      insertOp)
    return success();
  }
};
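// Drops leading unit dims from a transfer op's mask: if the lower-rank mask
// type is broadcastable back to the old mask type, the mask is simply
// extracted, otherwise it is shape-cast. The signature head below is restored
// from the call sites in the transfer patterns further down.
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,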
                                  VectorType oldMaskType) {
  // Infer the new mask type from the new vector type and permutation map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim));
  }
  return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
}
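// Rewrites vector.transfer_read so that leading unit dims are dropped from
// the result vector type, trimming the permutation map, in_bounds attribute,
// and mask to match.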
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // Already-masked ops and 0-d transfers are not supported.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getBase().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask;
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    // ... (elided: broadcast newRead back to oldType and replace `read`)
    return success();
  }
};
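// The write-side counterpart: drops leading unit dims from the vector written
// by vector.transfer_write, trimming the permutation map, in_bounds
// attribute, and mask to match.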
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // Already-masked ops and 0-d transfers are not supported.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    if (write.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(write.getBase().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = vector::ExtractOp::create(
        rewriter, write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getBase(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getBase(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};
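// Cast away the leading unit dim, if it exists, of the given contract op's
// accumulator, returning the value that replaces the contraction's result.
// When `maskingOp` is non-null the new contraction is re-wrapped in a
// vector.mask with a correspondingly reduced mask.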
FailureOr<Value> mlir::vector::castAwayContractionLeadingOneDim(
    vector::ContractionOp contractOp, MaskingOpInterface maskingOp,
    RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Only one dim is dropped per application; applying the pattern greedily
  // drops further leading unit dims.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  // Only parallel iterators can be dropped.
  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // If the dim to drop is a leading dim of this operand, it can be removed
    // with a vector.extract.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // Either the operand must be transposed so that the dim to drop becomes
      // leading, or that dim does not appear in this operand at all.
      bool transposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Transposes that only permute outer unit dims move no data and can be
      // omitted.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] :
           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
          break;
        }
      }

      // Do the transpose now if needed so that the correct dim can be dropped
      // with an extract below.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }
    // If the dim to drop is still not leading, it does not occur in this
    // operand and no extract is needed.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    newOperands.push_back(validExtract
                              ? vector::ExtractOp::create(rewriter, loc,
                                                          operands[it.index()],
                                                          splatZero(dropDim))
                              : operands[it.index()]);
  }

  // Depending on whether the contraction is masked, the replacement is a new
  // vector.contract or a vector.mask wrapping it.
  Operation *newOp = vector::ContractionOp::create(
      rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  if (maskingOp) {
    auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
                                             splatZero(dropDim));
    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
  }

  return vector::BroadcastOp::create(rewriter, loc,
                                     contractOp->getResultTypes()[0],
                                     newOp->getResult(0));
}
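// Pattern wrapper around the rewrite above; MaskableOpRewritePattern matches
// vector.contract both standalone and when nested inside vector.mask.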
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};
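// Drops leading unit dims from elementwise (ElementwiseMappable) ops with a
// single vector result: vector operands are reduced with vector.extract and
// the op is recreated on the lower-rank operands.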
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(vector::ExtractOp::create(
            rewriter, op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    // ... (elided: broadcast newOp's result back to vecType and replace op)
    return success();
  }
};
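// Drops leading unit dims from vector.constant_mask. The mask sizes of the
// dropped dims are folded into the first kept dim: if any dropped unit dim
// has mask size 0 the whole mask is empty, so the product of the leading
// sizes becomes the new leading mask size.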
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

    // Fold the leading unit mask dims into the first kept dim.
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
                                                  newType, newDimSizes);
    // ... (elided: broadcast newMask back to oldType and replace `mask`)
    return success();
  }
};
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
}
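// Typical driver usage (a sketch, not part of this file; the surrounding
// pass, `funcOp`, and the greedy-driver call are assumptions):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();
//
// (On older MLIR releases the driver entry point is named
// applyPatternsAndFoldGreedily.)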