18 #include "llvm/ADT/MapVector.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "vector-unroll"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
39 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
40 return constExpr.getValue() == 0;
46 if (isBroadcast(dim.value()))
48 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
53 builder.
create<affine::AffineApplyOp>(loc, map, indices[pos]);
70 static std::optional<SmallVector<int64_t>>
74 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
75 LDBG(
"--no filter constraint -> BAIL");
79 "vector unrolling expects the native shape or native"
80 "shape call back function to be set");
81 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
82 if (!unrollableVectorOp) {
83 LDBG(
"--not an unrollable op -> BAIL");
86 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
87 if (!maybeUnrollShape) {
88 LDBG(
"--could not get shape of op " << *op <<
" -> BAIL");
92 llvm::interleaveComma(*maybeUnrollShape,
DBGS() <<
"--vector op shape: ");
93 llvm::dbgs() <<
"\n";);
95 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
97 LDBG(
"--no unrolling target shape defined " << *op <<
"-> SKIP");
100 LLVM_DEBUG(llvm::interleaveComma(*targetShape,
DBGS() <<
"--target shape: ");
101 llvm::dbgs() <<
"\n";);
104 if (!maybeShapeRatio) {
105 LDBG(
"--could not compute integral shape ratio -> BAIL");
108 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
109 LDBG(
"--no unrolling needed -> SKIP");
112 LDBG(
"--found an integral shape ratio to unroll to -> SUCCESS");
120 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t
>(numLoops)));
121 if (
options.traversalOrderCallback !=
nullptr) {
122 std::optional<SmallVector<int64_t>> order =
123 options.traversalOrderCallback(op);
125 loopOrder = std::move(*order);
133 struct UnrollTransferReadPattern
141 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
144 if (readOp.getTransferRank() == 0)
146 if (readOp.getMask())
151 auto sourceVectorType = readOp.getVectorType();
158 loc, sourceVectorType, rewriter.
getZeroAttr(sourceVectorType));
162 readOp.getIndices().end());
169 readOp.getPermutationMap(), loc, rewriter);
170 auto slicedRead = rewriter.
create<vector::TransferReadOp>(
171 loc, targetType, readOp.getSource(), indices,
172 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
173 readOp.getInBoundsAttr());
175 result = rewriter.
create<vector::InsertStridedSliceOp>(
176 loc, slicedRead, result, elementOffsets, strides);
186 struct UnrollTransferWritePattern
194 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
197 if (writeOp.getTransferRank() == 0)
200 if (writeOp.getMask())
205 auto sourceVectorType = writeOp.getVectorType();
210 writeOp.getIndices().end());
216 Value slicedVector = rewriter.
create<vector::ExtractStridedSliceOp>(
217 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
220 writeOp.getPermutationMap(), loc, rewriter);
222 loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
223 indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
226 resultTensor = slicedWrite->
getResult(0);
229 rewriter.
replaceOp(writeOp, resultTensor);
239 struct OffsetMapInfo {
245 return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
254 struct UnrollContractionPattern
262 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
267 auto dstVecType = cast<VectorType>(contractOp.getResultType());
271 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
272 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
279 contractOp.getIteratorTypes().size(), contractOp,
options);
286 auto extractOperand = [&](
unsigned index,
Value operand,
292 slicesOperands[index] = rewriter.
create<vector::ExtractStridedSliceOp>(
293 loc, operand, operandOffets, operandShape, operandStrides);
297 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
300 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
303 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
306 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
308 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
313 auto *accIt = accCache.find(accOffets);
314 if (accIt != accCache.end())
315 slicesOperands[2] = accIt->second;
317 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
321 auto targetType =
VectorType::get(dstShape, dstVecType.getElementType());
323 rewriter, loc, contractOp, slicesOperands, targetType);
329 accCache[dstOffets] = newOp->
getResult(0);
333 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
334 for (
const auto &it : accCache) {
336 result = rewriter.
create<vector::InsertStridedSliceOp>(
337 loc, it.second, result, it.first, dstStrides);
347 struct UnrollMultiReductionPattern
355 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
357 std::optional<SmallVector<int64_t>> targetShape =
366 Location loc = reductionOp.getLoc();
374 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
375 loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
376 operands.push_back(slicedOperand);
379 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
380 if (!reductionOp.isReducedDim(i)) {
381 destOffset.push_back(offsets[i]);
382 dstShape.push_back((*targetShape)[i]);
389 auto *accIt = accCache.find(destOffset);
390 if (accIt != accCache.end())
393 acc = rewriter.
create<vector::ExtractStridedSliceOp>(
394 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
395 operands.push_back(acc);
397 dstShape, reductionOp.getSourceVectorType().getElementType());
399 operands, targetType);
401 accCache[destOffset] = result;
405 loc, reductionOp.getDestType(),
407 for (
const auto &it : accCache) {
409 result = rewriter.
create<vector::InsertStridedSliceOp>(
410 loc, it.second, result, it.first, dstStrides);
427 LogicalResult matchAndRewrite(
Operation *op,
436 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
440 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
442 VectorType newVecType =
450 auto vecType = dyn_cast<VectorType>(operand.get().getType());
452 extractOperands.push_back(operand.get());
455 extractOperands.push_back(
456 rewriter.
create<vector::ExtractStridedSliceOp>(
457 loc, operand.get(), offsets, *targetShape, strides));
460 rewriter, loc, op, extractOperands, newVecType);
461 result = rewriter.
create<vector::InsertStridedSliceOp>(
462 loc, newOp->
getResult(0), result, offsets, strides);
472 struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
479 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
481 std::optional<SmallVector<int64_t>> targetShape =
488 Location loc = reductionOp.getLoc();
489 Value accumulator =
nullptr;
493 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
494 loc, reductionOp.getVector(), offsets, *targetShape, strides);
496 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
497 Value result = newOp->getResult(0);
501 accumulator = result;
505 accumulator, result);
509 rewriter.
replaceOp(reductionOp, accumulator);
517 struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
524 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
526 if (transposeOp.getResultVectorType().getRank() == 0)
531 auto originalVectorType = transposeOp.getResultVectorType();
533 Location loc = transposeOp.getLoc();
538 loc, originalVectorType, rewriter.
getZeroAttr(originalVectorType));
548 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
549 permutedShape[indices.value()] = (*targetShape)[indices.index()];
551 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
552 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
554 Value transposedSlice =
555 rewriter.
create<vector::TransposeOp>(loc, slicedOperand, permutation);
556 result = rewriter.
create<vector::InsertStridedSliceOp>(
557 loc, transposedSlice, result, elementOffsets, strides);
574 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
576 VectorType sourceVectorType = gatherOp.getVectorType();
577 if (sourceVectorType.getRank() == 0)
588 loc, sourceVectorType, rewriter.
getZeroAttr(sourceVectorType));
599 Value indexSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
600 loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
601 Value maskSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
602 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
603 Value passThruSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
604 loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
605 auto slicedGather = rewriter.
create<vector::GatherOp>(
606 loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
607 indexSubVec, maskSubVec, passThruSubVec);
609 result = rewriter.
create<vector::InsertStridedSliceOp>(
610 loc, slicedGather, result, elementOffsets, strides);
625 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
626 UnrollContractionPattern, UnrollElementwisePattern,
627 UnrollReductionPattern, UnrollMultiReductionPattern,
628 UnrollTransposePattern, UnrollGatherPattern>(
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a tranfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.