17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/InterleavedRange.h"
23 #define DEBUG_TYPE "vector-unroll"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
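
/// From sliceTransferIndices (computes the indices of a slice of a transfer
/// op): broadcast dimensions of the permutation map, i.e. results equal to the
/// constant 0, are skipped; every other result offsets the corresponding index
/// by the slice's element offset through an affine.apply.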
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
  return constExpr.getValue() == 0;

if (isBroadcast(dim.value()))
  continue;
unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
slicedIndices[pos] =
    builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
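
/// Return the target shape for unrolling for the given op, or std::nullopt if
/// the op should not (or cannot) be unrolled.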
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--no filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << " -> SKIP");
    return std::nullopt;
  }
  LDBG("--target shape: " << llvm::interleaved(*targetShape));

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order)
      loopOrder = std::move(*order);
  }
  return loopOrder;
}
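
/// Pattern that unrolls a vector.transfer_read into transfer_reads of the
/// target shape and reassembles the full vector with insert_strided_slice.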
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());

    // Per tile: compute the transfer indices of the slice, read it, and insert
    // it into the full result vector.
    SmallVector<Value> indices = sliceTransferIndices(
        elementOffsets, originalIndices, readOp.getPermutationMap(), loc,
        rewriter);
    auto slicedRead = rewriter.create<vector::TransferReadOp>(
        loc, targetType, readOp.getSource(), indices,
        readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
        readOp.getInBoundsAttr());
    result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
        loc, slicedRead, result, elementOffsets, strides);
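
/// Pattern that unrolls a vector.transfer_write by extracting slices of the
/// written vector and emitting one transfer_write per tile; in the tensor case
/// the produced result is threaded from one write to the next.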
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());

    // Per tile: extract the slice to store and emit a transfer_write for it.
    Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
    SmallVector<Value> indices = sliceTransferIndices(
        elementOffsets, originalIndices, writeOp.getPermutationMap(), loc,
        rewriter);
    Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
        loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
        indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
    if (!slicedWrite->getResults().empty())
      resultTensor = slicedWrite->getResult(0);

    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
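
/// DenseMap info for SmallVector<int64_t> offset keys, used to cache the
/// accumulator slice computed for each destination offset.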
struct OffsetMapInfo {
  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }
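
/// Pattern that unrolls a vector.contract into contractions of the target
/// shape, caching accumulator slices per destination offset so that the
/// partial reductions of each tile chain into one another.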
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    // Extract the slice of an operand according to its permutation map and the
    // current offsets.
    auto extractOperand = [&](unsigned index, Value operand,
                              AffineMap permutationMap,
                              ArrayRef<int64_t> operandOffets) {
      slicesOperands[index] =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, operand, operandOffets, operandShape, operandStrides);
    };

    AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
    extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
    AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
    extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

    // Reuse the accumulator slice produced for a previous tile if available,
    // otherwise slice it from the original accumulator.
    AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
    auto *accIt = accCache.find(accOffets);
    if (accIt != accCache.end())
      slicesOperands[2] = accIt->second;
    else
      extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

    auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, contractOp, slicesOperands, targetType);
    accCache[dstOffets] = newOp->getResult(0);

    // Assemble the final result from the cached accumulator slices.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
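
/// Pattern that unrolls a vector.multi_reduction: each source tile is reduced
/// into an accumulator slice keyed by the non-reduced offsets, and the cached
/// slices are assembled back into the destination vector.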
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    Location loc = reductionOp.getLoc();

    // Slice the source for the current tile.
    Value slicedOperand = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
    operands.push_back(slicedOperand);

    // The destination offsets and shape keep only the non-reduced dimensions.
    for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
      if (!reductionOp.isReducedDim(i)) {
        destOffset.push_back(offsets[i]);
        dstShape.push_back((*targetShape)[i]);
      }
    }

    // Reuse the accumulator slice cached for this destination offset if it
    // exists, otherwise slice the original accumulator.
    auto *accIt = accCache.find(destOffset);
    if (accIt != accCache.end())
      acc = accIt->second;
    else
      acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
    operands.push_back(acc);

    auto targetType = VectorType::get(
        dstShape, reductionOp.getSourceVectorType().getElementType());
    Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                   operands, targetType);
    Value result = newOp->getResult(0);
    accCache[destOffset] = result;

    // After the tile loop, assemble the result from the cached slices.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
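
/// matchAndRewrite of UnrollElementwisePattern: unrolls elementwise vector ops
/// by slicing every vector operand, cloning the op at the target shape, and
/// inserting each partial result back into the full vector.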
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");

    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Slice every vector operand; scalar operands are forwarded unchanged.
    SmallVector<Value> extractOperands;
    for (OpOperand &operand : op->getOpOperands()) {
      auto vecType = dyn_cast<VectorType>(operand.get().getType());
      if (!vecType) {
        extractOperands.push_back(operand.get());
        continue;
      }
      extractOperands.push_back(
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, operand.get(), offsets, *targetShape, strides));
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, op, extractOperands, newVecType);
    result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
        loc, newOp->getResult(0), result, offsets, strides);
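
/// Pattern that unrolls a vector.reduction by reducing each slice and
/// combining the partial results with makeArithReduction.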
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;

    // Reduce each slice and chain the partial results through the accumulator.
    Value slicedOperand = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, reductionOp.getVector(), offsets, *targetShape, strides);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
    Value result = newOp->getResult(0);
    if (!accumulator)
      accumulator = result;
    else
      accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                       accumulator, result);

    rewriter.replaceOp(reductionOp, accumulator);
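
/// Pattern that unrolls a vector.transpose: each slice is extracted with
/// permuted offsets, transposed at the target shape, and inserted back at the
/// original offsets.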
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    Location loc = transposeOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));

    // Compute the slice offsets and shape in the permuted (source) space.
    for (auto indices : llvm::enumerate(permutation)) {
      permutedOffsets[indices.value()] = elementOffsets[indices.index()];
      permutedShape[indices.value()] = (*targetShape)[indices.index()];
    }

    Value slicedOperand = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, transposeOp.getVector(), permutedOffsets, permutedShape, strides);
    Value transposedSlice =
        rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
    result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
        loc, transposedSlice, result, elementOffsets, strides);
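
/// matchAndRewrite of UnrollGatherPattern: slices the index, mask and
/// pass-through vectors, emits one vector.gather per tile, and assembles the
/// tiles into the result.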
  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));

    // Per tile: slice the index, mask and pass-through vectors, gather at the
    // target shape, and insert the tile into the result.
    Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
    Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
    Value passThruSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
        loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
    auto slicedGather = rewriter.create<vector::GatherOp>(
        loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
        indexSubVec, maskSubVec, passThruSubVec);
    result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
        loc, slicedGather, result, elementOffsets, strides);
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
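
// A minimal usage sketch, not part of this file: it assumes a caller that owns
// a RewritePatternSet and later hands it to a rewrite driver such as the
// greedy pattern rewriter. The helper name `addVectorUnrollPatterns` and the
// fixed tile size of 4 are illustrative assumptions, not APIs defined here.
static void addVectorUnrollPatterns(RewritePatternSet &patterns) {
  vector::UnrollVectorOptions options;
  // Request 4-element tiles along every dimension of each unrollable vector
  // op; getTargetShape bails out for ops whose shape is not an integral
  // multiple of this native shape.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        auto unrollable = dyn_cast<VectorUnrollOpInterface>(op);
        if (!unrollable)
          return std::nullopt;
        std::optional<SmallVector<int64_t>> shape =
            unrollable.getShapeForUnroll();
        if (!shape)
          return std::nullopt;
        return SmallVector<int64_t>(shape->size(), 4);
      });
  vector::populateVectorUnrollPatterns(patterns, options);
}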