18#include "llvm/ADT/STLExtras.h"
20#define DEBUG_TYPE "vector-drop-unit-dim"
34 while (!newShape.empty() && newShape.front() == 1 &&
35 !newScalableDims.front()) {
36 newShape = newShape.drop_front(1);
37 newScalableDims = newScalableDims.drop_front(1);
41 if (newShape.empty()) {
42 newShape = oldShape.take_back();
43 newScalableDims = oldType.getScalableDims().take_back();
45 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
56struct CastAwayExtractStridedSliceLeadingOneDim
60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
61 PatternRewriter &rewriter)
const override {
65 VectorType oldSrcType = extractOp.getSourceVectorType();
68 if (newSrcType.getRank() == oldSrcType.getRank())
71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
73 VectorType oldDstType = extractOp.getType();
74 VectorType newDstType =
75 VectorType::get(oldDstType.getShape().drop_front(dropCount),
76 oldDstType.getElementType(),
77 oldDstType.getScalableDims().drop_front(dropCount));
79 Location loc = extractOp.getLoc();
81 Value newSrcVector = vector::ExtractOp::create(
82 rewriter, loc, extractOp.getSource(),
splatZero(dropCount));
87 extractOp.getOffsets().getValue().drop_front(dropCount));
89 extractOp.getSizes().getValue().drop_front(dropCount));
91 extractOp.getStrides().getValue().drop_front(dropCount));
93 auto newExtractOp = vector::ExtractStridedSliceOp::create(
94 rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
106struct CastAwayInsertStridedSliceLeadingOneDim
110 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
111 PatternRewriter &rewriter)
const override {
112 VectorType oldSrcType = insertOp.getSourceVectorType();
114 VectorType oldDstType = insertOp.getDestVectorType();
117 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
118 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
119 if (srcDropCount == 0 && dstDropCount == 0)
123 Location loc = insertOp.getLoc();
125 Value newSrcVector = vector::ExtractOp::create(
126 rewriter, loc, insertOp.getValueToStore(),
splatZero(srcDropCount));
127 Value newDstVector = vector::ExtractOp::create(
128 rewriter, loc, insertOp.getDest(),
splatZero(dstDropCount));
131 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
133 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
135 auto newInsertOp = vector::InsertStridedSliceOp::create(
136 rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
148struct CastAwayInsertLeadingOneDim :
public OpRewritePattern<vector::InsertOp> {
151 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
152 PatternRewriter &rewriter)
const override {
153 Type oldSrcType = insertOp.getValueToStoreType();
154 Type newSrcType = oldSrcType;
155 int64_t oldSrcRank = 0, newSrcRank = 0;
156 if (
auto type = dyn_cast<VectorType>(oldSrcType)) {
158 oldSrcRank = type.getRank();
159 newSrcRank = cast<VectorType>(newSrcType).getRank();
162 VectorType oldDstType = insertOp.getDestVectorType();
165 int64_t srcDropCount = oldSrcRank - newSrcRank;
166 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
167 if (srcDropCount == 0 && dstDropCount == 0)
171 Location loc = insertOp.getLoc();
173 Value newSrcVector = insertOp.getValueToStore();
174 if (oldSrcRank != 0) {
175 newSrcVector = vector::ExtractOp::create(
176 rewriter, loc, insertOp.getValueToStore(),
splatZero(srcDropCount));
178 Value newDstVector = vector::ExtractOp::create(
179 rewriter, loc, insertOp.getDest(),
splatZero(dstDropCount));
185 unsigned oldPosRank = insertOp.getNumIndices();
186 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
187 SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
188 SmallVector<OpFoldResult> newPosition =
189 llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
190 newPosition.resize(newDstType.getRank() - newSrcRank,
193 auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
194 newDstVector, newPosition);
205 VectorType oldMaskType) {
214 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
215 return vector::ExtractOp::create(
b, loc, mask,
splatZero(dropDim));
217 return vector::ShapeCastOp::create(
b, loc, newMaskType, mask);
223struct CastAwayTransferReadLeadingOneDim
227 LogicalResult matchAndRewrite(vector::TransferReadOp read,
228 PatternRewriter &rewriter)
const override {
230 if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
233 if (read.getTransferRank() == 0)
236 auto shapedType = cast<ShapedType>(read.getBase().getType());
237 if (shapedType.getElementType() != read.getVectorType().getElementType())
240 VectorType oldType = read.getVectorType();
243 if (newType == oldType)
246 AffineMap oldMap = read.getPermutationMap();
247 ArrayRef<AffineExpr> newResults =
248 oldMap.
getResults().take_back(newType.getRank());
254 if (read.getInBounds())
256 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
258 Value mask = Value();
259 if (read.getMask()) {
260 VectorType maskType = read.getMaskType();
261 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
262 newType, newMap, maskType);
265 auto newRead = vector::TransferReadOp::create(
266 rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
267 AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
277struct CastAwayTransferWriteLeadingOneDim
281 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
282 PatternRewriter &rewriter)
const override {
284 if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
287 if (write.getTransferRank() == 0)
290 auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
291 if (shapedType.getElementType() != write.getVectorType().getElementType())
294 VectorType oldType = write.getVectorType();
296 if (newType == oldType)
298 int64_t dropDim = oldType.getRank() - newType.getRank();
301 ArrayRef<AffineExpr> newResults =
302 oldMap.
getResults().take_back(newType.getRank());
308 if (write.getInBounds())
310 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
312 auto newVector = vector::ExtractOp::create(
313 rewriter, write.getLoc(), write.getVector(),
splatZero(dropDim));
315 if (write.getMask()) {
316 VectorType maskType = write.getMaskType();
317 Value newMask = dropUnitDimsFromMask(
318 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
320 write, newVector, write.getBase(), write.getIndices(),
321 AffineMapAttr::get(newMap), newMask, inBoundsAttr);
326 write, newVector, write.getBase(), write.getIndices(),
327 AffineMapAttr::get(newMap), inBoundsAttr);
335mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
336 MaskingOpInterface maskingOp,
338 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
339 if (oldAccType ==
nullptr)
341 if (oldAccType.getRank() < 2)
343 if (oldAccType.getShape()[0] != 1)
349 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
352 auto oldIteratorTypes = contractOp.getIteratorTypes();
355 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
361 for (
const auto &it : llvm::enumerate(oldIteratorTypes)) {
363 if (currDim == dimToDrop)
365 newIteratorTypes.push_back(it.value());
369 contractOp.getAcc()};
371 auto loc = contractOp.getLoc();
373 for (
const auto &it : llvm::enumerate(oldIndexingMaps)) {
376 bool validExtract =
false;
378 auto map = it.value();
379 int64_t orginalZeroDim = it.value().getDimPosition(0);
380 if (orginalZeroDim != dimToDrop) {
386 bool transposeNeeded =
false;
390 for (
int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
391 int64_t currDim = map.getDimPosition(i);
392 if (currDim == dimToDrop) {
393 transposeNeeded =
true;
394 perm.insert(perm.begin(), i);
396 transposeResults.insert(transposeResults.begin(), targetExpr);
400 transposeResults.push_back(targetExpr);
407 bool transposeNonOuterUnitDims =
false;
408 auto operandShape = cast<ShapedType>(operands[it.index()].getType());
409 for (
auto [
index, dim] :
412 operandShape.getDimSize(
index) != 1) {
413 transposeNonOuterUnitDims =
true;
420 if (transposeNeeded) {
422 contractOp.getContext());
423 if (transposeNonOuterUnitDims) {
424 operands[it.index()] = rewriter.
createOrFold<vector::TransposeOp>(
425 loc, operands[it.index()], perm);
433 if (map.getDimPosition(0) == dimToDrop)
436 for (
int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
437 int64_t currDim = map.getDimPosition(i);
438 if (currDim == dimToDrop)
442 currDim < dimToDrop ? currDim : currDim - 1);
443 results.push_back(targetExpr);
445 newIndexingMaps.push_back(
AffineMap::get(map.getNumDims() - 1, 0, results,
446 contractOp.getContext()));
449 newOperands.push_back(validExtract
450 ? vector::ExtractOp::create(rewriter, loc,
451 operands[it.index()],
453 : operands[it.index()]);
458 Operation *newOp = vector::ContractionOp::create(
459 rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
461 rewriter.
getArrayAttr(newIteratorTypes), contractOp.getKind());
464 auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
470 return vector::BroadcastOp::create(rewriter, loc,
471 contractOp->getResultTypes()[0],
482struct CastAwayContractionLeadingOneDim
484 using MaskableOpRewritePattern::MaskableOpRewritePattern;
487 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
488 MaskingOpInterface maskingOp,
489 PatternRewriter &rewriter)
const override {
490 return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
508 CastAwayElementwiseLeadingOneDim(MLIRContext *context,
509 PatternBenefit benefit = 1)
510 : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
512 LogicalResult matchAndRewrite(Operation *op,
513 PatternRewriter &rewriter)
const override {
520 if (newVecType == vecType)
522 int64_t dropDim = vecType.getRank() - newVecType.getRank();
523 SmallVector<Value, 4> newOperands;
525 if (
auto opVecType = dyn_cast<VectorType>(operand.getType())) {
526 newOperands.push_back(vector::ExtractOp::create(
529 newOperands.push_back(operand);
534 newOperands, newVecType, op->
getAttrs());
543struct CastAwayConstantMaskLeadingOneDim
547 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
548 PatternRewriter &rewriter)
const override {
549 VectorType oldType = mask.getType();
552 if (newType == oldType)
555 int64_t dropDim = oldType.getRank() - newType.getRank();
556 ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
560 int64_t flatLeadingSize =
561 llvm::product_of(dimSizes.take_front(dropDim + 1));
562 SmallVector<int64_t> newDimSizes = {flatLeadingSize};
563 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
565 auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
566 newType, newDimSizes);
574void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
577 .add<CastAwayExtractStridedSliceLeadingOneDim,
578 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
579 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
580 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
581 CastAwayContractionLeadingOneDim>(
patterns.getContext(), benefit);
static SmallVector< int64_t > splatZero(int64_t rank)
Return a smallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.