18 #include "llvm/ADT/STLExtras.h"
20 #define DEBUG_TYPE "vector-drop-unit-dim"
34 while (!newShape.empty() && newShape.front() == 1 &&
35 !newScalableDims.front()) {
36 newShape = newShape.drop_front(1);
37 newScalableDims = newScalableDims.drop_front(1);
41 if (newShape.empty()) {
42 newShape = oldShape.take_back();
43 newScalableDims = oldType.getScalableDims().take_back();
45 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
56 struct CastAwayExtractStridedSliceLeadingOneDim
60 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
65 VectorType oldSrcType = extractOp.getSourceVectorType();
68 if (newSrcType.getRank() == oldSrcType.getRank())
71 int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
73 VectorType oldDstType = extractOp.getType();
74 VectorType newDstType =
76 oldDstType.getElementType(),
77 oldDstType.getScalableDims().drop_front(dropCount));
81 Value newSrcVector = vector::ExtractOp::create(
82 rewriter, loc, extractOp.getSource(),
splatZero(dropCount));
87 extractOp.getOffsets().getValue().drop_front(dropCount));
89 extractOp.getSizes().getValue().drop_front(dropCount));
91 extractOp.getStrides().getValue().drop_front(dropCount));
93 auto newExtractOp = vector::ExtractStridedSliceOp::create(
94 rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes,
106 struct CastAwayInsertStridedSliceLeadingOneDim
110 LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
112 VectorType oldSrcType = insertOp.getSourceVectorType();
114 VectorType oldDstType = insertOp.getDestVectorType();
117 int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
118 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
119 if (srcDropCount == 0 && dstDropCount == 0)
125 Value newSrcVector = vector::ExtractOp::create(
126 rewriter, loc, insertOp.getValueToStore(),
splatZero(srcDropCount));
127 Value newDstVector = vector::ExtractOp::create(
128 rewriter, loc, insertOp.getDest(),
splatZero(dstDropCount));
131 insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
133 insertOp.getStrides().getValue().take_back(newSrcType.getRank()));
135 auto newInsertOp = vector::InsertStridedSliceOp::create(
136 rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets,
148 struct CastAwayInsertLeadingOneDim :
public OpRewritePattern<vector::InsertOp> {
151 LogicalResult matchAndRewrite(vector::InsertOp insertOp,
153 Type oldSrcType = insertOp.getValueToStoreType();
154 Type newSrcType = oldSrcType;
155 int64_t oldSrcRank = 0, newSrcRank = 0;
156 if (
auto type = dyn_cast<VectorType>(oldSrcType)) {
158 oldSrcRank = type.getRank();
159 newSrcRank = cast<VectorType>(newSrcType).getRank();
162 VectorType oldDstType = insertOp.getDestVectorType();
165 int64_t srcDropCount = oldSrcRank - newSrcRank;
166 int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
167 if (srcDropCount == 0 && dstDropCount == 0)
173 Value newSrcVector = insertOp.getValueToStore();
174 if (oldSrcRank != 0) {
175 newSrcVector = vector::ExtractOp::create(
176 rewriter, loc, insertOp.getValueToStore(),
splatZero(srcDropCount));
178 Value newDstVector = vector::ExtractOp::create(
179 rewriter, loc, insertOp.getDest(),
splatZero(dstDropCount));
185 unsigned oldPosRank = insertOp.getNumIndices();
186 unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
189 llvm::to_vector(
ArrayRef(oldPosition).take_back(newPosRank));
190 newPosition.resize(newDstType.getRank() - newSrcRank,
193 auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector,
194 newDstVector, newPosition);
205 VectorType oldMaskType) {
214 int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
215 return vector::ExtractOp::create(b, loc, mask,
splatZero(dropDim));
217 return vector::ShapeCastOp::create(b, loc, newMaskType, mask);
223 struct CastAwayTransferReadLeadingOneDim
227 LogicalResult matchAndRewrite(vector::TransferReadOp read,
230 if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
233 if (read.getTransferRank() == 0)
236 auto shapedType = cast<ShapedType>(read.getBase().getType());
237 if (shapedType.getElementType() != read.getVectorType().getElementType())
240 VectorType oldType = read.getVectorType();
243 if (newType == oldType)
248 oldMap.
getResults().take_back(newType.getRank());
253 ArrayAttr inBoundsAttr;
254 if (read.getInBounds())
256 read.getInBoundsAttr().getValue().take_back(newType.getRank()));
259 if (read.getMask()) {
260 VectorType maskType = read.getMaskType();
261 mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
262 newType, newMap, maskType);
265 auto newRead = vector::TransferReadOp::create(
266 rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(),
277 struct CastAwayTransferWriteLeadingOneDim
281 LogicalResult matchAndRewrite(vector::TransferWriteOp write,
284 if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
287 if (write.getTransferRank() == 0)
290 auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
291 if (shapedType.getElementType() != write.getVectorType().getElementType())
294 VectorType oldType = write.getVectorType();
296 if (newType == oldType)
298 int64_t dropDim = oldType.getRank() - newType.getRank();
302 oldMap.
getResults().take_back(newType.getRank());
307 ArrayAttr inBoundsAttr;
308 if (write.getInBounds())
310 write.getInBoundsAttr().getValue().take_back(newType.getRank()));
312 auto newVector = vector::ExtractOp::create(
313 rewriter, write.getLoc(), write.getVector(),
splatZero(dropDim));
315 if (write.getMask()) {
316 VectorType maskType = write.getMaskType();
317 Value newMask = dropUnitDimsFromMask(
318 rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
320 write, newVector, write.getBase(), write.getIndices(),
326 write, newVector, write.getBase(), write.getIndices(),
336 MaskingOpInterface maskingOp,
338 VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
339 if (oldAccType ==
nullptr)
341 if (oldAccType.getRank() < 2)
343 if (oldAccType.getShape()[0] != 1)
349 auto oldIndexingMaps = contractOp.getIndexingMapsArray();
352 auto oldIteratorTypes = contractOp.getIteratorTypes();
355 int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
362 int64_t currDim = it.index();
363 if (currDim == dimToDrop)
365 newIteratorTypes.push_back(it.value());
369 contractOp.getAcc()};
371 auto loc = contractOp.getLoc();
376 bool validExtract =
false;
378 auto map = it.value();
379 int64_t orginalZeroDim = it.value().getDimPosition(0);
380 if (orginalZeroDim != dimToDrop) {
386 bool transposeNeeded =
false;
390 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
391 int64_t currDim = map.getDimPosition(i);
392 if (currDim == dimToDrop) {
393 transposeNeeded =
true;
394 perm.insert(perm.begin(), i);
396 transposeResults.insert(transposeResults.begin(), targetExpr);
400 transposeResults.push_back(targetExpr);
407 bool transposeNonOuterUnitDims =
false;
408 auto operandShape = cast<ShapedType>(operands[it.index()].getType());
409 for (
auto [index, dim] :
411 if (dim !=
static_cast<int64_t
>(index) &&
412 operandShape.getDimSize(index) != 1) {
413 transposeNonOuterUnitDims =
true;
420 if (transposeNeeded) {
422 contractOp.getContext());
423 if (transposeNonOuterUnitDims) {
424 operands[it.index()] = rewriter.
createOrFold<vector::TransposeOp>(
425 loc, operands[it.index()], perm);
433 if (map.getDimPosition(0) == dimToDrop)
436 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
437 int64_t currDim = map.getDimPosition(i);
438 if (currDim == dimToDrop)
442 currDim < dimToDrop ? currDim : currDim - 1);
443 results.push_back(targetExpr);
445 newIndexingMaps.push_back(
AffineMap::get(map.getNumDims() - 1, 0, results,
446 contractOp.getContext()));
449 newOperands.push_back(validExtract
450 ? vector::ExtractOp::create(rewriter, loc,
451 operands[it.index()],
453 : operands[it.index()]);
458 Operation *newOp = vector::ContractionOp::create(
459 rewriter, loc, newOperands[0], newOperands[1], newOperands[2],
461 rewriter.
getArrayAttr(newIteratorTypes), contractOp.getKind());
464 auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(),
470 return vector::BroadcastOp::create(rewriter, loc,
471 contractOp->getResultTypes()[0],
482 struct CastAwayContractionLeadingOneDim
484 using MaskableOpRewritePattern::MaskableOpRewritePattern;
487 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
488 MaskingOpInterface maskingOp,
508 CastAwayElementwiseLeadingOneDim(
MLIRContext *context,
512 LogicalResult matchAndRewrite(
Operation *op,
520 if (newVecType == vecType)
522 int64_t dropDim = vecType.getRank() - newVecType.getRank();
525 if (
auto opVecType = dyn_cast<VectorType>(operand.getType())) {
526 newOperands.push_back(vector::ExtractOp::create(
529 newOperands.push_back(operand);
534 newOperands, newVecType, op->
getAttrs());
543 struct CastAwayConstantMaskLeadingOneDim
547 LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
549 VectorType oldType = mask.getType();
552 if (newType == oldType)
555 int64_t dropDim = oldType.getRank() - newType.getRank();
560 int64_t flatLeadingSize =
561 llvm::product_of(dimSizes.take_front(dropDim + 1));
563 newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
565 auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(),
566 newType, newDimSizes);
574 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
577 .add<CastAwayExtractStridedSliceLeadingOneDim,
578 CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
579 CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
580 CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
581 CastAwayContractionLeadingOneDim>(
patterns.getContext(), benefit);
static SmallVector< int64_t > splatZero(int64_t rank)
Returns a SmallVector of size rank containing all zeros.
static VectorType trimLeadingOneDims(VectorType oldType)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< Value > castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter)
Cast away the leading unit dim, if it exists, for the given contract op.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.