19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
24 #define DEBUG_TYPE "vector-unrolling"
41 class DecomposeShapeIterator {
46 int64_t maxIndexVal{1};
52 :
vectorShape(targetShape.begin(), targetShape.end()),
53 loopOrder(loopOrder.begin(), loopOrder.end()),
54 sliceStrides(originalShape.size()) {
55 assert(originalShape.size() >= targetShape.size());
56 assert(loopOrder.size() == originalShape.size());
60 assert(maybeShapeRatio &&
"Shape does not evenly divide");
68 for (
auto idx : llvm::reverse(loopOrder)) {
69 sliceStrides[idx] = accum;
70 accum *= sliceDimCounts[idx];
80 for (
auto idx : loopOrder) {
81 vectorOffsets[idx] = index / sliceStrides[idx];
82 index %= sliceStrides[idx];
87 int64_t maxIndex()
const {
return maxIndexVal; }
95 return elementOffsets;
109 return constExpr.getValue() == 0;
115 if (isBroadcast(dim.value()))
117 unsigned pos = dim.value().cast<
AffineDimExpr>().getPosition();
121 slicedIndices[pos] = builder.
create<AffineApplyOp>(loc, map, indices[pos]);
123 return slicedIndices;
138 static std::optional<SmallVector<int64_t>>
143 "vector unrolling expects the native shape or native"
144 "shape call back function to be set");
145 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
146 if (!unrollableVectorOp)
148 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
149 if (!maybeUnrollShape)
151 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
155 if (!maybeShapeRatio ||
156 llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; }))
165 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t
>(numLoops)));
166 if (
options.traversalOrderCallback !=
nullptr) {
167 std::optional<SmallVector<int64_t>> order =
168 options.traversalOrderCallback(op);
170 loopOrder = std::move(*order);
178 struct UnrollTransferReadPattern
189 if (readOp.getTransferRank() == 0)
191 if (readOp.getMask())
196 auto sourceVectorType = readOp.getVectorType();
203 loc, sourceVectorType, rewriter.
getZeroAttr(sourceVectorType));
205 VectorType::get(*targetShape, sourceVectorType.getElementType());
207 readOp.getIndices().end());
211 DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
213 for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
217 readOp.getPermutationMap(), loc, rewriter);
218 auto slicedRead = rewriter.
create<vector::TransferReadOp>(
219 loc, targetType, readOp.getSource(), indices,
220 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
221 readOp.getInBoundsAttr());
223 result = rewriter.
create<vector::InsertStridedSliceOp>(
224 loc, slicedRead, result, elementOffsets, strides);
234 struct UnrollTransferWritePattern
242 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
245 if (writeOp.getTransferRank() == 0)
248 if (writeOp.getMask())
253 auto sourceVectorType = writeOp.getVectorType();
258 writeOp.getIndices().end());
262 DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
265 for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
267 Value slicedVector = rewriter.
create<vector::ExtractStridedSliceOp>(
268 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
271 writeOp.getPermutationMap(), loc, rewriter);
273 loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
274 indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
277 resultTensor = slicedWrite->
getResult(0);
280 rewriter.
replaceOp(writeOp, resultTensor);
290 struct OffsetMapInfo {
296 return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
305 struct UnrollContractionPattern
313 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
318 auto dstVecType = contractOp.getResultType().cast<VectorType>();
322 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
323 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
330 contractOp.getIteratorTypes().size(), contractOp,
options);
331 DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
333 const int64_t sliceCount = indexToOffsets.maxIndex();
334 for (int64_t i = 0; i < sliceCount; i++) {
339 auto extractOperand = [&](
unsigned index,
Value operand,
345 slicesOperands[index] = rewriter.
create<vector::ExtractStridedSliceOp>(
346 loc, operand, operandOffets, operandShape, operandStrides);
350 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
353 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
355 if (slicesOperands.size() > 3)
356 extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
360 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
363 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
365 if (slicesOperands.size() > 4)
366 extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
369 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
374 auto accIt = accCache.find(accOffets);
375 if (accIt != accCache.end())
376 slicesOperands[2] = accIt->second;
378 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
382 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
384 rewriter, loc, contractOp, slicesOperands, targetType);
390 accCache[dstOffets] = newOp->
getResult(0);
394 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
395 for (
const auto &it : accCache) {
397 result = rewriter.
create<vector::InsertStridedSliceOp>(
398 loc, it.second, result, it.first, dstStrides);
408 struct UnrollMultiReductionPattern
416 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
418 std::optional<SmallVector<int64_t>> targetShape =
430 Location loc = reductionOp.getLoc();
435 for (int64_t i = 0; i < sliceCount; i++) {
441 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
442 loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
443 operands.push_back(slicedOperand);
446 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
447 if (!reductionOp.isReducedDim(i)) {
448 destOffset.push_back(offsets[i]);
449 dstShape.push_back((*targetShape)[i]);
456 auto accIt = accCache.find(destOffset);
457 if (accIt != accCache.end())
460 acc = rewriter.
create<vector::ExtractStridedSliceOp>(
461 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
462 operands.push_back(acc);
463 auto targetType = VectorType::get(
464 dstShape, reductionOp.getSourceVectorType().getElementType());
466 operands, targetType);
468 accCache[destOffset] = result;
472 loc, reductionOp.getDestType(),
474 for (
const auto &it : accCache) {
476 result = rewriter.
create<vector::InsertStridedSliceOp>(
477 loc, it.second, result, it.first, dstStrides);
503 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
509 loc, dstVecType, rewriter.
getZeroAttr(dstVecType));
511 VectorType newVecType =
512 VectorType::get(*targetShape, dstVecType.getElementType());
517 for (int64_t i = 0; i < sliceCount; i++) {
522 auto vecType = operand.get().getType().template dyn_cast<VectorType>();
524 extractOperands.push_back(operand.get());
527 extractOperands.push_back(
528 rewriter.
create<vector::ExtractStridedSliceOp>(
529 loc, operand.get(), offsets, *targetShape, strides));
532 rewriter, loc, op, extractOperands, newVecType);
533 result = rewriter.
create<vector::InsertStridedSliceOp>(
534 loc, newOp->
getResult(0), result, offsets, strides);
544 struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
551 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
553 std::optional<SmallVector<int64_t>> targetShape =
559 int64_t sliceCount = ratio[0];
562 Location loc = reductionOp.getLoc();
563 Value accumulator =
nullptr;
568 for (int64_t i = 0; i < sliceCount; ++i) {
572 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
573 loc, reductionOp.getVector(), offsets, *targetShape, strides);
575 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
576 Value result = newOp->getResult(0);
580 accumulator = result;
584 accumulator, result);
588 rewriter.
replaceOp(reductionOp, accumulator);
596 struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
603 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
605 if (transposeOp.getResultType().getRank() == 0)
610 auto originalVectorType = transposeOp.getResultType();
612 Location loc = transposeOp.getLoc();
618 loc, originalVectorType, rewriter.
getZeroAttr(originalVectorType));
620 transposeOp.getTransp(permutation);
625 for (int64_t i = 0; i < sliceCount; i++) {
632 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
633 permutedShape[indices.value()] = (*targetShape)[indices.index()];
635 Value slicedOperand = rewriter.
create<vector::ExtractStridedSliceOp>(
636 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
638 Value transposedSlice =
639 rewriter.
create<vector::TransposeOp>(loc, slicedOperand, permutation);
640 result = rewriter.
create<vector::InsertStridedSliceOp>(
641 loc, transposedSlice, result, elementOffsets, strides);
656 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
657 UnrollContractionPattern, UnrollElementwisePattern,
658 UnrollReductionPattern, UnrollMultiReductionPattern,
static llvm::ManagedStatic< PassManagerOptions > options
static ArrayRef< int64_t > vectorShape(Type type)
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static SmallVector< int64_t > getVectorOffset(ArrayRef< int64_t > ratioStrides, int64_t index, ArrayRef< int64_t > targetShape)
During unrolling from originalShape to targetShape return the offset for the slice index.
An integer constant appearing in an affine expression.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
MLIRContext * getContext() const
Attribute getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip of v1 and v2 multiplied elementwise.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Given a set of sizes, compute and return the strides (i.e.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Compute and return the multi-dimensional integral ratio of subShape to the trailing dimensions of sha...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.