#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"

#define DEBUG_TYPE "vector-unroll"

using namespace mlir;
using namespace mlir::vector;
/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = permutationMap.getContext();
  // A broadcast dimension is encoded as a constant-0 result in the
  // permutation map; it has no corresponding index to offset.
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute the sliced indices by adding `elementOffsets[i]` to `indices[i]`.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}
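// Illustrative example (not part of the source): for a transfer with
// permutation map (d0, d1) -> (d1, d0) and element offsets [2, 4], the
// index bound to d1 is advanced by 2 and the index bound to d0 by 4:
//   %i1' = affine.apply affine_map<(d0) -> (d0 + 2)>(%i1)
//   %i0' = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
// A constant-0 result in the map (a broadcast dimension) is skipped.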
/// Compute the indices for a load/store op by adding `offsets` to the
/// trailing `originalIndices`.
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices.begin(), originalIndices.end());
  // Only the trailing `offsets.size()` indices are adjusted; leading
  // dimensions keep their original index values.
  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0)
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
  }
  return indices;
}
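// Illustrative example (not part of the source): a load from a rank-3
// memref producing a rank-1 vector only offsets the innermost index:
//   originalIndices = [%i, %j, %k], offsets = [8]
//   -> indices = [%i, %j, %k + 8]
// Zero offsets reuse the original index value without emitting an add.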
/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--filter constraint not satisfied -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native shape "
         "callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  // Default to the identity order [0, 1, ..., numLoops - 1].
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order)
      loopOrder = std::move(*order);
  }
  return loopOrder;
}
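// Illustrative example (not from the source): for three unrolled loops the
// default order is [0, 1, 2], i.e. slices are visited row-major over the
// tile grid. A callback returning, say, [2, 0, 1] reorders the traversal,
// which can group slices that share an accumulator or improve locality of
// the extract/insert chains; it does not change which slices are created.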
namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector, initialized to zero.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
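// Illustrative IR (not from the source): unrolling a 4x4 read to 2x2
// slices produces four small reads stitched back together, roughly:
//   %0 = vector.transfer_read %base[%i, %j] ... : vector<2x2xf32>
//   %1 = vector.insert_strided_slice %0, %zero {offsets = [0, 0], ...}
//   %2 = vector.transfer_read %base[%i, %j'] ... : vector<2x2xf32>
//   %3 = vector.insert_strided_slice %2, %1 {offsets = [0, 2], ...}
// and likewise for offsets [2, 0] and [2, 2], with the adjusted indices
// coming from sliceTransferIndices.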
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    // Bail out if the source rank does not match the target slice rank:
    // ExtractStridedSlice requires matching ranks.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, chain each new write through the previous
      // write's result.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// DenseMap info for vectors of offsets, used below to cache the accumulator
/// slice computed for each destination offset.
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to construct the slice of each operand.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled, since
      // reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }
};
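// Illustrative example (not from the source): for a matmul-style contraction
// with iterator types [parallel, parallel, reduction] unrolled by [2, 2, 2],
// slices that differ only in the reduction dimension share the same
// destination offsets. The second such slice finds the first one's result in
// accCache and accumulates into it instead of re-reading the original acc
// operand, so the reduction chains correctly across the unrolled tiles.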
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // The accumulator slice keeps only the non-reduced dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }
};
struct UnrollElementwisePattern : public RewritePattern {
  // (Constructor matching any op and storing `options` elided.)

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    // Bail out if the source rank does not match the target slice rank:
    // ExtractStridedSlice requires matching ranks.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          // Scalar operands are used as-is by every unrolled clone.
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
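// Illustrative IR (not from the source): an elementwise op such as
//   %r = arith.addf %a, %b : vector<4xf32>
// unrolled with native shape [2] becomes two adds on vector<2xf32> slices,
// each reinserted into the zero-initialized result with
// vector.insert_strided_slice at offsets [0] and [2].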
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create the unrolled vector reductions.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }
};
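// Illustrative IR (not from the source): with native shape [4],
//   %r = vector.reduction <add>, %v : vector<8xf32> into f32
// becomes two reductions over vector<4xf32> slices whose scalar results are
// combined with an arith.addf produced by makeArithReduction.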
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector, initialized to zero.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      // Compute the source offsets and shape by inverting the permutation.
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = vector::TransposeOp::create(
          rewriter, loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }
};
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector, initialized to zero.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice from each of the
      // index, mask, and pass-through vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }
};
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();

    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // Prepare the result vector, initialized to zero.
    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }
};
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();

    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }
};
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  // (Constructor and stored `options` member as in UnrollTransferReadPattern.)

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar source: broadcast the scalar into each slice directly.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector source: extract the matching slice of the source, using the
        // trailing dimensions of the result offsets and shape.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // Adjust the offset and shape for broadcast (size-1) dimensions.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc,
                                                     broadcastOp, newSrc,
                                                     targetType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(broadcastOp, result);
    return success();
  }
};

} // namespace
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern>(
      patterns.getContext(), options, benefit);
}
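// A minimal usage sketch (not part of this file; the native shape, the
// filter, and the driver function name are illustrative assumptions, while
// the UnrollVectorOptions setters, populateVectorUnrollPatterns, and
// applyPatternsGreedily are the upstream MLIR entry points):
//
//   void unrollVectorOpsTo4x4(Operation *root) {
//     MLIRContext *ctx = root->getContext();
//     RewritePatternSet patterns(ctx);
//     vector::UnrollVectorOptions options;
//     // Target 4x4 slices for every unrollable vector op.
//     options.setNativeShape({4, 4});
//     // Hypothetical filter: only unroll contractions.
//     options.setFilterConstraint([](Operation *op) {
//       return success(isa<vector::ContractionOp>(op));
//     });
//     vector::populateVectorUnrollPatterns(patterns, options);
//     (void)applyPatternsGreedily(root, std::move(patterns));
//   }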