#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"

#define DEBUG_TYPE "vector-unroll"
/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  // A permutation-map result that is the constant 0 denotes a broadcast
  // dimension, which has no corresponding index to shift.
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}
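
// Illustrative example (the values are assumptions, not from this file):
// with permutationMap = (d0, d1) -> (d1, d0) and elementOffsets = [2, 4],
// the helper rewrites indices[1] to affine.apply (d0 -> d0 + 2)(indices[1])
// and indices[0] to affine.apply (d0 -> d0 + 4)(indices[0]).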
// Compute the new indices by adding `offsets` to `originalIndices`. If
// offsets.size() is smaller than originalIndices.size(), only the trailing
// indices are updated.
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices.begin(), originalIndices.end());
  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0)
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
  }
  return indices;
}
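
// Illustrative example: with three original indices and offsets = [0, 8],
// only the last index changes, becoming originalIndices[2] + 8; a zero
// offset reuses the original index value unchanged.

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`, preserving the original attributes.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}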
/// Return the target shape for unrolling for the given `op`, or std::nullopt
/// if the op should not or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--no filter constraint -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
/// Return the unroll traversal order for `op`: the identity order over the
/// unrolled loop nest unless the options provide a callback that overrides it.
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order)
      loopOrder = std::move(*order);
  }
  return loopOrder;
}
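
/// Unrolls a vector.transfer_read into reads of the target shape and
/// reassembles the full result with vector.insert_strided_slice ops.
/// Illustrative effect (example shapes, not taken from this file): one 4x4
/// read becomes four 2x2 reads at tile offsets [0,0], [0,2], [2,0], [2,2].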
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    // ... compute targetShape, targetType, loc, originalSize, strides, and
    // loopOrder ...
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }
};
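
/// Unrolls a vector.transfer_write by extracting tiles of the target shape
/// from the written vector and emitting one smaller transfer_write per tile.
/// For tensor destinations, each sliced write consumes the tensor produced by
/// the previous one, so the final tensor threads through all slices.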
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();
    // TODO: support masked writes.
    if (writeOp.getMask())
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    // ... compute targetShape, loc, strides, and loopOrder ...
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, chain the produced tensor into the next write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }
};
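
/// DenseMapInfo for using SmallVector<int64_t> offset vectors as map keys.
/// The unroll patterns below use it to cache partial accumulator values per
/// destination offset.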
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }
  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
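
/// Unrolls vector.contract into contracts of the target shape. Slices of the
/// lhs/rhs/acc operands are computed through the op's indexing maps, and
/// partial accumulators are cached per destination offset so reduction tiles
/// chain through the cache instead of re-reading the original acc.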
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract its
      // slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Keep the accumulated value until all loops are unrolled, since the
      // reduction loops keep updating it.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble the accumulator tiles back into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }
};
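
/// Unrolls vector.multi_reduction. Each source tile is reduced to a tile over
/// the non-reduced dims; tiles that reduce onto the same destination offset
/// reuse the cached partial accumulator.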
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project the offsets/shape onto the non-reduced dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // Reuse the accumulator if a previous tile already produced it;
      // otherwise extract the initial value from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      accCache[destOffset] = newOp->getResult(0);
    }
    // Assemble the accumulator tiles back into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
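
/// Unrolls elementwise ops (UnrollElementwisePattern). When the target shape
/// has lower rank than the original vector, it is padded with leading unit
/// dims and the extracted slices are shape_cast down to the target rank.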
struct UnrollElementwisePattern : public RewritePattern {
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();
    Location loc = op->getLoc();

    // Pad the target shape with leading ones so it has the same rank as the
    // original shape.
    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);

    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          // Scalar operands are used as-is.
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);
        // Drop the leading unit dims introduced by the rank adjustment.
        if (adjustedTargetShapeRank > targetShapeRank)
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        extractOperands.push_back(extracted);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);
      Value computeResult = newOp->getResult(0);
      // Use strides sized to the target shape for the insertion.
      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
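
/// Unrolls vector.reduction: each slice is reduced separately and the partial
/// results are combined with makeArithReduction.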
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create the unrolled vector reductions.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);
      if (!accumulator) {
        // First reduction: seed the accumulator.
        accumulator = result;
      } else {
        // Subsequent reductions: combine with the running accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }
    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }
};
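
/// Unrolls vector.transpose. The source offsets and shape of each tile are
/// the destination tile offsets and shape mapped through the permutation.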
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = vector::TransposeOp::create(
          rewriter, loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }
};
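
/// Unrolls vector.gather by slicing the index, mask, and pass-through vectors
/// with the same tile offsets as the result.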
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    // ... compute targetShape, loc, originalSize, strides, and loopOrder ...
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To produce the unrolled gather, extract the same slice from each of
      // the index, mask, and pass-through vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }
};
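
/// Unrolls vector.load into loads of the target shape, offsetting the
/// trailing memref indices per tile via sliceLoadStoreIndices.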
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    // ... compute targetShape, loc, originalShape, strides, targetVecType,
    // and loopOrder ...
    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }
};
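
/// Unrolls vector.store: extracts each tile of the stored value and emits one
/// store per tile; the original op is erased rather than replaced since it
/// has no results.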
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();
    // ... compute targetShape, loc, originalShape, strides, and loopOrder ...
    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }
};
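
/// Unrolls vector.broadcast. For vector sources, dimensions that are
/// broadcast (source size 1) keep offset 0 and size 1 in the extracted slice.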
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();
    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));
    // ... compute originalShape and unit strides ...
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar source: broadcast it unchanged.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector source: compute the slice. Broadcast dims keep offset 0 and
        // size 1.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(rank, 1);
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(broadcastOp, result);
    return success();
  }
};
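
/// Unrolls vector.to_elements by first unrolling the source vector into
/// sub-vectors (vector::unrollVectorValue) and emitting one to_elements op
/// per sub-vector, concatenating all the results.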
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(op.getSource(), rewriter);
    if (failed(result))
      return failure();
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }
};
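
/// Unrolls vector.step (1-D): each tile is the base step of the target shape
/// plus a splatted constant tile offset.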
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();
    VectorType vecType = stepOp.getType();
    if (vecType.isScalable()) {
      return rewriter.notifyMatchFailure(
          stepOp, "unrolling scalable step vectors is not supported");
    }
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides = {1};
    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    VectorType targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      // Shift the base step by the tile offset, splatted across the tile.
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(targetVecType,
                                 rewriter.getIndexAttr(offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }
};
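
/// Unrolls vector.from_elements along the outermost dimension: each
/// sub-vector is built from the matching contiguous slice of the original
/// element list.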
struct UnrollFromElements : public OpRewritePattern<vector::FromElementsOp> {
  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();
    // Build one smaller from_elements op per unrolled sub-vector, each taking
    // the matching contiguous slice of the original element list.
    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };
    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }
};
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
                                                    options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollToElements>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
}
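
// Example driver (an illustrative sketch, not part of this file): unroll
// transfer reads/writes to 2x2 tiles inside a pass. `setNativeShapeFn` and
// `applyPatternsGreedily` are existing MLIR APIs; the 2x2 shape and the
// `funcOp` handle are assumptions for the example.
//
//   RewritePatternSet patterns(&getContext());
//   vector::UnrollVectorOptions options;
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         if (isa<vector::TransferReadOp, vector::TransferWriteOp>(op))
//           return SmallVector<int64_t>{2, 2};
//         return std::nullopt;
//       });
//   vector::populateVectorUnrollPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();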