18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/DebugLog.h"
21 #include "llvm/Support/InterleavedRange.h"
24 #define DEBUG_TYPE "vector-unroll"
37 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
38 return constExpr.getValue() == 0;
44 if (isBroadcast(dim.value()))
46 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
51 affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
63 assert(offsets.size() <= originalIndices.size() &&
64 "Offsets should not exceed the number of original indices");
67 auto start = indices.size() - offsets.size();
70 indices[start + i] = arith::AddIOp::create(
71 rewriter, loc, originalIndices[start + i],
90 static std::optional<SmallVector<int64_t>>
94 LDBG() <<
"--no filter constraint -> BAIL";
98 "vector unrolling expects the native shape or native"
99 "shape call back function to be set");
100 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
101 if (!unrollableVectorOp) {
102 LDBG() <<
"--not an unrollable op -> BAIL";
105 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
106 if (!maybeUnrollShape) {
107 LDBG() <<
"--could not get shape of op " << *op <<
" -> BAIL";
110 LDBG() <<
"--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
112 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
114 LDBG() <<
"--no unrolling target shape defined " << *op <<
"-> SKIP";
117 LDBG() <<
"--target shape: " << llvm::interleaved(*targetShape);
120 if (!maybeShapeRatio) {
121 LDBG() <<
"--could not compute integral shape ratio -> BAIL";
124 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
125 LDBG() <<
"--no unrolling needed -> SKIP";
128 LDBG() <<
"--found an integral shape ratio to unroll to -> SUCCESS";
136 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t
>(numLoops)));
137 if (
options.traversalOrderCallback !=
nullptr) {
138 std::optional<SmallVector<int64_t>> order =
139 options.traversalOrderCallback(op);
141 loopOrder = std::move(*order);
149 struct UnrollTransferReadPattern
157 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
160 if (readOp.getTransferRank() == 0)
162 if (readOp.getMask())
167 auto sourceVectorType = readOp.getVectorType();
174 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
179 readOp.getIndices().end());
186 readOp.getPermutationMap(), loc, rewriter);
187 auto slicedRead = vector::TransferReadOp::create(
188 rewriter, loc, targetType, readOp.getBase(), indices,
189 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
190 readOp.getInBoundsAttr());
192 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
193 loc, slicedRead, result, elementOffsets, strides);
203 struct UnrollTransferWritePattern
211 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
214 if (writeOp.getTransferRank() == 0)
217 if (writeOp.getMask())
222 auto sourceVectorType = writeOp.getVectorType();
229 if (originalSize.size() != targetShape->size())
232 "expected source input vector rank to match target shape rank");
235 writeOp.getIndices().end());
242 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
245 writeOp.getPermutationMap(), loc, rewriter);
246 Operation *slicedWrite = vector::TransferWriteOp::create(
247 rewriter, loc, slicedVector,
248 resultTensor ? resultTensor : writeOp.getBase(), indices,
249 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
252 resultTensor = slicedWrite->
getResult(0);
255 rewriter.
replaceOp(writeOp, resultTensor);
265 struct OffsetMapInfo {
271 return static_cast<unsigned>(llvm::hash_combine_range(v));
280 struct UnrollContractionPattern
288 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
293 auto dstVecType = cast<VectorType>(contractOp.getResultType());
297 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
298 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
305 contractOp.getIteratorTypes().size(), contractOp,
options);
312 auto extractOperand = [&](
unsigned index,
Value operand,
318 slicesOperands[index] =
320 loc, operand, operandOffets, operandShape, operandStrides);
324 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
327 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
330 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
333 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
335 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
340 auto *accIt = accCache.find(accOffets);
341 if (accIt != accCache.end())
342 slicesOperands[2] = accIt->second;
344 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
348 auto targetType =
VectorType::get(dstShape, dstVecType.getElementType());
350 rewriter, loc, contractOp, slicesOperands, targetType);
356 accCache[dstOffets] = newOp->
getResult(0);
359 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
361 for (
const auto &it : accCache) {
363 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
364 loc, it.second, result, it.first, dstStrides);
374 struct UnrollMultiReductionPattern
382 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
384 auto resultType = reductionOp->getResult(0).getType();
385 if (resultType.isIntOrFloat()) {
387 "Unrolling scalars is not supported");
389 std::optional<SmallVector<int64_t>> targetShape =
398 Location loc = reductionOp.getLoc();
406 Value slicedOperand =
408 loc, reductionOp.getSource(), offsets, *targetShape,
410 operands.push_back(slicedOperand);
413 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
414 if (!reductionOp.isReducedDim(i)) {
415 destOffset.push_back(offsets[i]);
416 dstShape.push_back((*targetShape)[i]);
423 auto *accIt = accCache.find(destOffset);
424 if (accIt != accCache.end())
427 acc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
428 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
429 operands.push_back(acc);
431 dstShape, reductionOp.getSourceVectorType().getElementType());
433 operands, targetType);
435 accCache[destOffset] = result;
438 Value result = arith::ConstantOp::create(
439 rewriter, loc, reductionOp.getDestType(),
441 for (
const auto &it : accCache) {
443 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
444 loc, it.second, result, it.first, dstStrides);
461 LogicalResult matchAndRewrite(
Operation *op,
470 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
474 if (originalSize.size() != targetShape->size())
476 op,
"expected input vector rank to match target shape rank");
479 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
482 VectorType newVecType =
490 auto vecType = dyn_cast<VectorType>(operand.get().getType());
492 extractOperands.push_back(operand.get());
495 extractOperands.push_back(
497 loc, operand.get(), offsets, *targetShape, strides));
500 rewriter, loc, op, extractOperands, newVecType);
501 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
502 loc, newOp->
getResult(0), result, offsets, strides);
512 struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
519 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
521 std::optional<SmallVector<int64_t>> targetShape =
528 Location loc = reductionOp.getLoc();
529 Value accumulator =
nullptr;
533 Value slicedOperand =
535 loc, reductionOp.getVector(), offsets, *targetShape, strides);
537 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
538 Value result = newOp->getResult(0);
542 accumulator = result;
546 accumulator, result);
550 rewriter.
replaceOp(reductionOp, accumulator);
558 struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
565 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
567 if (transposeOp.getResultVectorType().getRank() == 0)
572 auto originalVectorType = transposeOp.getResultVectorType();
574 Location loc = transposeOp.getLoc();
579 arith::ConstantOp::create(rewriter, loc, originalVectorType,
590 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
591 permutedShape[indices.value()] = (*targetShape)[indices.index()];
593 Value slicedOperand =
595 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
598 loc, slicedOperand, permutation);
599 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
600 loc, transposedSlice, result, elementOffsets, strides);
617 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
619 VectorType sourceVectorType = gatherOp.getVectorType();
620 if (sourceVectorType.getRank() == 0)
631 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
644 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
646 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
647 Value passThruSubVec =
649 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
651 auto slicedGather = vector::GatherOp::create(
652 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
653 indexSubVec, maskSubVec, passThruSubVec);
655 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
656 loc, slicedGather, result, elementOffsets, strides);
672 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
674 VectorType vecType = loadOp.getVectorType();
684 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
697 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
698 loadOp.getBase(), indices);
699 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
700 loc, slicedLoad, result, offsets, strides);
716 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
718 VectorType vecType = storeOp.getVectorType();
728 Value base = storeOp.getBase();
729 Value vector = storeOp.getValueToStore();
739 loc, vector, offsets, *targetShape, strides);
740 vector::StoreOp::create(rewriter, loc, slice, base, indices);
750 struct UnrollBroadcastPattern :
public OpRewritePattern<vector::BroadcastOp> {
757 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
763 Location loc = broadcastOp.getLoc();
764 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
765 VectorType resType = broadcastOp.getResultVectorType();
766 VectorType targetType =
767 resType.cloneWith(*targetShape, resType.getElementType());
768 Value result = arith::ConstantOp::create(rewriter, loc, resType,
779 newSrc = broadcastOp.getSource();
782 int64_t rank = srcType.getRank();
788 for (int64_t i = 0; i < rank; ++i) {
789 if (srcType.getDimSize(i) == 1) {
794 newSrc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
795 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
801 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
802 loc, newOp->
getResult(0), result, offsets, strides);
830 struct UnrollToElements final :
public OpRewritePattern<vector::ToElementsOp> {
837 LogicalResult matchAndRewrite(vector::ToElementsOp op,
841 FailureOr<SmallVector<Value>> result =
849 for (
Value vector : vectors) {
851 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
852 llvm::append_range(results, subElements.getResults());
894 LogicalResult matchAndRewrite(vector::StepOp stepOp,
896 std::optional<SmallVector<int64_t>> targetShape =
901 VectorType vecType = stepOp.getType();
902 if (vecType.isScalable()) {
906 int64_t originalSize = vecType.getShape()[0];
910 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
915 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
918 Value bcastOffset = arith::ConstantOp::create(
919 rewriter, loc, targetVecType,
924 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
926 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
927 loc, tileStep, result, offsets, strides);
961 LogicalResult matchAndRewrite(vector::FromElementsOp op,
966 VectorType subTy, int64_t index) {
967 size_t subTyNumElements = subTy.getNumElements();
968 assert((index + 1) * subTyNumElements <= allElements.size() &&
971 allElements.slice(index * subTyNumElements, subTyNumElements);
972 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
984 void mlir::vector::populateVectorUnrollPatterns(
987 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
988 UnrollContractionPattern, UnrollElementwisePattern,
989 UnrollReductionPattern, UnrollMultiReductionPattern,
990 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
991 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
992 UnrollToElements, UnrollStepPattern>(
patterns.getContext(),
996 void mlir::vector::populateVectorToElementsUnrollPatterns(
1002 void mlir::vector::populateVectorFromElementsUnrollPatterns(
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.