19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "vector-unroll"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
40 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
41 return constExpr.getValue() == 0;
47 if (isBroadcast(dim.value()))
49 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
54 builder.
create<affine::AffineApplyOp>(loc, map, indices[pos]);
71 static std::optional<SmallVector<int64_t>>
76 LDBG(
"--no filter constraint -> BAIL");
80 "vector unrolling expects the native shape or native"
81 "shape call back function to be set");
82 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
83 if (!unrollableVectorOp) {
84 LDBG(
"--not an unrollable op -> BAIL");
87 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
88 if (!maybeUnrollShape) {
89 LDBG(
"--could not get shape of op " << *op <<
" -> BAIL");
93 llvm::interleaveComma(*maybeUnrollShape,
DBGS() <<
"--vector op shape: ");
94 llvm::dbgs() <<
"\n";);
96 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
98 LDBG(
"--no unrolling target shape defined " << *op <<
"-> SKIP");
101 LLVM_DEBUG(llvm::interleaveComma(*targetShape,
DBGS() <<
"--target shape: ");
102 llvm::dbgs() <<
"\n";);
105 if (!maybeShapeRatio) {
106 LDBG(
"--could not compute integral shape ratio -> BAIL");
109 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
110 LDBG(
"--no unrolling needed -> SKIP");
113 LDBG(
"--found an integral shape ratio to unroll to -> SUCCESS");
121 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t
>(numLoops)));
122 if (
options.traversalOrderCallback !=
nullptr) {
123 std::optional<SmallVector<int64_t>> order =
124 options.traversalOrderCallback(op);
126 loopOrder = std::move(*order);
134 struct UnrollTransferReadPattern
145 if (readOp.getTransferRank() == 0)
147 if (readOp.getMask())
152 auto sourceVectorType = readOp.getVectorType();
159 loc, sourceVectorType, rewriter.
getZeroAttr(sourceVectorType));
163 readOp.getIndices().end());
170 readOp.getPermutationMap(), loc, rewriter);
171 auto slicedRead = rewriter.
create<vector::TransferReadOp>(
172 loc, targetType, readOp.getSource(), indices,
173 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
174 readOp.getInBoundsAttr());
176 result = rewriter.
create<vector::InsertStridedSliceOp>(
177 loc, slicedRead, result, elementOffsets, strides);
187 struct UnrollTransferWritePattern
195 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
198 if (writeOp.getTransferRank() == 0)
201 if (writeOp.getMask())
206 auto sourceVectorType = writeOp.getVectorType();
211 writeOp.getIndices().end());
217 Value slicedVector = rewriter.
create<vector::ExtractStridedSliceOp>(
218 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
221 writeOp.getPermutationMap(), loc, rewriter);
223 loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
224 indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
227 resultTensor = slicedWrite->
getResult(0);
230 rewriter.
replaceOp(writeOp, resultTensor);
240 struct OffsetMapInfo {
246 return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
255 struct UnrollContractionPattern
263 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
268 auto dstVecType = cast<VectorType>(contractOp.getResultType());
272 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
273 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
280 contractOp.getIteratorTypes().size(), contractOp,
options);
287 auto extractOperand = [&](
unsigned index,
Value operand,
293 slicesOperands[index] = rewriter.
create<vector::ExtractStridedSliceOp>(
294 loc, operand, operandOffets, operandShape, operandStrides);
298 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
301 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
304 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
307 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
309 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
314 auto accIt = accCache.find(accOffets);
315 if (accIt != accCache.end())
316 slicesOperands[2] = accIt->second;
318 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
322 auto targetType =
VectorType::get(dstShape, dstVecType.getElementType());
324 rewriter, loc, contractOp, slicesOperands, targetType);
330 accCache[dstOffets] = newOp->
getResult(0);
334 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
335 for (
const auto &it : accCache) {
337 result = rewriter.
create<vector::InsertStridedSliceOp>(
338 loc, it.second, result, it.first, dstStrides);
348 struct UnrollMultiReductionPattern
356 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
358 std::optional<SmallVector<int64_t>> targetShape =
367 Location loc = reductionOp.getLoc();
375 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
376 loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
377 operands.push_back(slicedOperand);
380 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
381 if (!reductionOp.isReducedDim(i)) {
382 destOffset.push_back(offsets[i]);
383 dstShape.push_back((*targetShape)[i]);
390 auto accIt = accCache.find(destOffset);
391 if (accIt != accCache.end())
394 acc = rewriter.
create<vector::ExtractStridedSliceOp>(
395 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
396 operands.push_back(acc);
398 dstShape, reductionOp.getSourceVectorType().getElementType());
400 operands, targetType);
402 accCache[destOffset] = result;
406 loc, reductionOp.getDestType(),
408 for (
const auto &it : accCache) {
410 result = rewriter.
create<vector::InsertStridedSliceOp>(
411 loc, it.second, result, it.first, dstStrides);
437 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
441 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
443 VectorType newVecType =
451 auto vecType = dyn_cast<VectorType>(operand.get().getType());
453 extractOperands.push_back(operand.get());
456 extractOperands.push_back(
457 rewriter.
create<vector::ExtractStridedSliceOp>(
458 loc, operand.get(), offsets, *targetShape, strides));
461 rewriter, loc, op, extractOperands, newVecType);
462 result = rewriter.
create<vector::InsertStridedSliceOp>(
463 loc, newOp->
getResult(0), result, offsets, strides);
473 struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
480 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
482 std::optional<SmallVector<int64_t>> targetShape =
489 Location loc = reductionOp.getLoc();
490 Value accumulator =
nullptr;
494 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
495 loc, reductionOp.getVector(), offsets, *targetShape, strides);
497 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
498 Value result = newOp->getResult(0);
502 accumulator = result;
506 accumulator, result);
510 rewriter.
replaceOp(reductionOp, accumulator);
518 struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
525 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
527 if (transposeOp.getResultVectorType().getRank() == 0)
532 auto originalVectorType = transposeOp.getResultVectorType();
534 Location loc = transposeOp.getLoc();
539 loc, originalVectorType, rewriter.
getZeroAttr(originalVectorType));
549 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
550 permutedShape[indices.value()] = (*targetShape)[indices.index()];
552 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
553 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
555 Value transposedSlice =
556 rewriter.
create<vector::TransposeOp>(loc, slicedOperand, permutation);
557 result = rewriter.
create<vector::InsertStridedSliceOp>(
558 loc, transposedSlice, result, elementOffsets, strides);
577 VectorType sourceVectorType = gatherOp.getVectorType();
578 if (sourceVectorType.getRank() == 0)
589 loc, sourceVectorType, rewriter.
getZeroAttr(sourceVectorType));
600 Value indexSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
601 loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
602 Value maskSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
603 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
604 Value passThruSubVec = rewriter.
create<vector::ExtractStridedSliceOp>(
605 loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
606 auto slicedGather = rewriter.
create<vector::GatherOp>(
607 loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
608 indexSubVec, maskSubVec, passThruSubVec);
610 result = rewriter.
create<vector::InsertStridedSliceOp>(
611 loc, slicedGather, result, elementOffsets, strides);
626 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
627 UnrollContractionPattern, UnrollElementwisePattern,
628 UnrollReductionPattern, UnrollMultiReductionPattern,
629 UnrollTransposePattern, UnrollGatherPattern>(
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.