#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"

#define DEBUG_TYPE "vector-unrolling"
/// Enumerates the slices of a vector of `originalShape` in units of
/// `targetShape`, visiting the unrolled dimensions in the given `loopOrder`.
class DecomposeShapeIterator {
  // ...
  int64_t maxIndexVal{1};

public:
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() >= targetShape.size());
    assert(loopOrder.size() == originalShape.size());
    // ...
    assert(maybeShapeRatio && "Shape does not evenly divide");
    // ...
    // Compute the stride of each unrolled dimension, innermost loop first.
    for (auto idx : llvm::reverse(loopOrder)) {
      sliceStrides[idx] = accum;
      accum *= sliceDimCounts[idx];
    }
    // ...
  }

  // Delinearize a linear slice index into per-dimension offsets, traversing
  // the dimensions in `loopOrder`.
  // ...
    for (auto idx : loopOrder) {
      vectorOffsets[idx] = index / sliceStrides[idx];
      index %= sliceStrides[idx];
    }
  // ...

  int64_t maxIndex() const { return maxIndexVal; }

  // Scale the delinearized offsets by `vectorShape` to get element offsets.
  // ...
    return elementOffsets;
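// Illustrative sketch (standalone program, not part of this file): the
// stride/delinearization arithmetic above, written against plain std::vector
// so it compiles in isolation. The shapes, loop order, and slice index below
// are made-up example values.
//
// Unrolling a vector<8x8> into vector<4x4> slices with loop order {0, 1}
// yields a 2x2 grid of slices; the grid strides are {2, 1}, so linear slice
// index 3 delinearizes to grid offsets {1, 1}, i.e. element offsets {4, 4}.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> originalShape = {8, 8}, targetShape = {4, 4};
  std::vector<int64_t> loopOrder = {0, 1};

  std::vector<int64_t> sliceDimCounts(originalShape.size());
  for (size_t d = 0; d < originalShape.size(); ++d)
    sliceDimCounts[d] = originalShape[d] / targetShape[d]; // {2, 2}

  // Strides over the slice grid: the innermost loop dimension varies fastest.
  std::vector<int64_t> sliceStrides(originalShape.size());
  int64_t accum = 1;
  for (auto it = loopOrder.rbegin(); it != loopOrder.rend(); ++it) {
    sliceStrides[*it] = accum;
    accum *= sliceDimCounts[*it];
  }
  assert(sliceStrides[0] == 2 && sliceStrides[1] == 1);

  // Delinearize slice index 3 and scale by the target shape.
  int64_t index = 3;
  std::vector<int64_t> elementOffsets(originalShape.size());
  for (int64_t d : loopOrder) {
    elementOffsets[d] = (index / sliceStrides[d]) * targetShape[d];
    index %= sliceStrides[d];
  }
  assert(elementOffsets[0] == 4 && elementOffsets[1] == 4);
  return 0;
}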
// In sliceTransferIndices: a permutation-map result that is the constant 0
// marks a broadcast dimension, which contributes no index.
      return constExpr.getValue() == 0;
  // ...
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
    // ...
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  // ...
  return slicedIndices;
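// Illustrative sketch (not from this file): how a per-slice index of the form
// `index + offset` can be materialized with an affine.apply, which is what the
// elided lines around the create<affine::AffineApplyOp> above do. The helper
// name and surrounding setup are assumptions made for the example.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

static mlir::Value addConstantOffset(mlir::OpBuilder &builder,
                                     mlir::Location loc, mlir::Value index,
                                     int64_t offset) {
  mlir::MLIRContext *ctx = builder.getContext();
  // Build the map (d0) -> (d0 + offset) and apply it to the original index.
  mlir::AffineExpr expr = mlir::getAffineDimExpr(0, ctx) +
                          mlir::getAffineConstantExpr(offset, ctx);
  auto map = mlir::AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
  return builder.create<mlir::affine::AffineApplyOp>(loc, map, index);
}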
/// Return the target shape for unrolling for the given `op`, or std::nullopt
/// if the op should not (or cannot) be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  // ...
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return std::nullopt;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return std::nullopt;
  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  // ...
  // Unrolling is a no-op when every dimension already matches the target.
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return std::nullopt;
  // ...
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    // ...
      loopOrder = std::move(*order);
    // ...
  }
  return loopOrder;
}
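// Illustrative sketch (not part of this file): how a client might populate the
// UnrollVectorOptions that getTargetShape/getUnrollOrder consult. The setter
// names follow the options struct declared in
// mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h; the chosen shape and
// the isa<> filter are made-up example policies.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include <optional>

static mlir::vector::UnrollVectorOptions makeExampleUnrollOptions() {
  using namespace mlir;
  vector::UnrollVectorOptions options;
  // Unroll every contraction to a fixed 4x4x4 tile; leave other ops alone.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        if (isa<vector::ContractionOp>(op))
          return SmallVector<int64_t>{4, 4, 4};
        return std::nullopt;
      });
  // Optionally restrict which ops are considered at all.
  options.setFilterConstraint(
      [](Operation *op) { return success(isa<vector::ContractionOp>(op)); });
  return options;
}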
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  // ...
    // 0-d and masked transfers are not unrolled.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    // ...
    auto sourceVectorType = readOp.getVectorType();
    // ...
    // Start from a zero constant of the original shape and insert each
    // unrolled read into it.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    // ...
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      // ...
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    // ...
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  // ...
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    // ...
    auto sourceVectorType = writeOp.getVectorType();
    // ...
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    // ...
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    // For tensor destinations, each unrolled write consumes the tensor
    // produced by the previous one.
    Value resultTensor;
    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
      // ...
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // ...
        resultTensor = slicedWrite->getResult(0);
    }
    // ...
      rewriter.replaceOp(writeOp, resultTensor);
    // ...
struct OffsetMapInfo {
  // ...
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  // ...
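// Illustrative sketch (not a verbatim copy of this file): OffsetMapInfo plays
// the role of an llvm::DenseMapInfo for SmallVector<int64_t> offset keys, so
// the patterns below can cache partial accumulators in an llvm::MapVector
// whose underlying DenseMap uses it as key info. A minimal key-info with the
// required four members could look like this; the sentinel key values are
// assumptions for the example.
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"

struct ExampleOffsetMapInfo {
  static llvm::SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
  static llvm::SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
  static unsigned getHashValue(const llvm::SmallVector<int64_t> &v) {
    return static_cast<unsigned>(
        llvm::hash_combine_range(v.begin(), v.end()));
  }
  static bool isEqual(const llvm::SmallVector<int64_t> &lhs,
                      const llvm::SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};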
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  // ...
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // ...
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    // ...
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // ...
    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    const int64_t sliceCount = indexToOffsets.maxIndex();
    for (int64_t i = 0; i < sliceCount; i++) {
      // ...
      // Helper that extracts the slice of `operand` whose offsets are obtained
      // by applying the operand's indexing map to the iteration-space offsets.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        // ...
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs and rhs operands.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      // ...
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      // ...
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      // ...
      // If a partial accumulator was already produced for these destination
      // offsets, reuse it; otherwise extract the slice from the original acc.
      auto accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
      // ...
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);
      // ...
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble the final result by inserting every cached partial accumulator
    // into a zero constant of the destination type.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      // ...
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    // ...
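// Illustrative sketch (standalone plain C++, not part of this file): why the
// accumulator cache above is keyed by destination offsets. When a
// contraction's iteration space (m, n, k) is unrolled, several slices share
// the same (m, n) output tile and differ only in k; chaining them through the
// cache folds all partial products into one value per tile. The shapes and
// the "+" combine step are made up for the example.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>

int main() {
  // Iteration space 2x2x4 unrolled in unit steps: 16 slices, 4 output tiles.
  std::map<std::pair<int64_t, int64_t>, int64_t> accCache;
  for (int64_t m = 0; m < 2; ++m)
    for (int64_t n = 0; n < 2; ++n)
      for (int64_t k = 0; k < 4; ++k) {
        int64_t partial = 1; // stand-in for one sliced contraction result
        auto key = std::make_pair(m, n);
        auto it = accCache.find(key);
        // Reuse the cached partial result as the accumulator, mirroring
        // `slicesOperands[2] = accIt->second` in the pattern above.
        accCache[key] = (it == accCache.end()) ? partial : it->second + partial;
      }
  // Each of the 4 output tiles accumulated 4 partial results.
  for (const auto &entry : accCache)
    std::cout << "(" << entry.first.first << ", " << entry.first.second
              << ") -> " << entry.second << "\n"; // prints 4 for each tile
  return 0;
}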
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  // ...
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    // ...
    Location loc = reductionOp.getLoc();
    // ...
    for (int64_t i = 0; i < sliceCount; i++) {
      // ...
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      // The destination slice keeps only the non-reduced dimensions.
      // ...
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      // ...
      // Reuse the cached partial accumulator for this destination offset, or
      // extract it from the original accumulator operand.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble the final result from the cached partial reductions.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      // ...
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    // ...
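// Illustrative sketch (standalone plain C++, not part of this file): how the
// loop above projects slice offsets onto the reduction destination by dropping
// the reduced dimensions. The shapes and the set of reduced dimensions are
// made up for the example.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // A 4x8 source reduced over dimension 1, unrolled with target shape 2x4.
  std::vector<int64_t> targetShape = {2, 4};
  std::vector<int64_t> offsets = {2, 4}; // offsets of one source slice
  std::vector<bool> isReducedDim = {false, true};

  std::vector<int64_t> destOffset, dstShape;
  for (size_t i = 0; i < targetShape.size(); ++i) {
    if (!isReducedDim[i]) {
      destOffset.push_back(offsets[i]);
      dstShape.push_back(targetShape[i]);
    }
  }
  // Only dimension 0 survives: the partial reduction is a length-2 slice that
  // lands at offset {2} of the destination, and every source slice with
  // offsets {2, *} shares that destination (hence the accumulator cache).
  assert((destOffset == std::vector<int64_t>{2}));
  assert((dstShape == std::vector<int64_t>{2}));
  return 0;
}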
// In UnrollElementwisePattern::matchAndRewrite:
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    // ...
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    // ...
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());
    // ...
    for (int64_t i = 0; i < sliceCount; i++) {
      // ...
      for (OpOperand &operand : op->getOpOperands()) {
        // Scalar operands are forwarded unchanged; vector operands are sliced
        // to the target shape.
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    // ...
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  // ...
  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    // ...
    // A vector.reduction is 1-D, so it unrolls along its single dimension.
    int64_t sliceCount = ratio[0];
    // ...
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (int64_t i = 0; i < sliceCount; ++i) {
      // ...
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);
      // The first slice seeds the accumulator; later partial results are
      // combined into it with the reduction's combining kind.
      if (i == 0) {
        accumulator = result;
      } else {
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }
    rewriter.replaceOp(reductionOp, accumulator);
    // ...
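// Illustrative sketch (standalone plain C++, not part of this file): the
// accumulator chaining above, for an additive reduction of 16 elements split
// into 4 slices of 4. The data and the use of plain + as the combining kind
// are made up for the example.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> data(16);
  std::iota(data.begin(), data.end(), 1); // 1, 2, ..., 16

  int64_t accumulator = 0;
  bool hasAccumulator = false;
  for (int64_t slice = 0; slice < 4; ++slice) {
    // Reduce one length-4 slice, as each unrolled vector.reduction does.
    int64_t partial = std::accumulate(data.begin() + slice * 4,
                                      data.begin() + (slice + 1) * 4,
                                      int64_t(0));
    // The first slice seeds the accumulator; later slices are combined into
    // it, mirroring the makeArithReduction call in the pattern.
    accumulator = hasAccumulator ? accumulator + partial : partial;
    hasAccumulator = true;
  }
  assert(accumulator == 136); // same result as reducing all 16 at once
  return 0;
}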
struct UnrollTransposePattern
    : public OpRewritePattern<vector::TransposeOp> {
  // ...
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    // ...
    auto originalVectorType = transposeOp.getResultVectorType();
    // ...
    Location loc = transposeOp.getLoc();
    // ...
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    SmallVector<int64_t> permutation;
    transposeOp.getTransp(permutation);
    for (int64_t i = 0; i < sliceCount; i++) {
      // ...
      // Permute the result-slice offsets and shape to find the region of the
      // transpose's source vector that produces this slice.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    // ...
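// Illustrative sketch (standalone plain C++, not part of this file): the
// offset permutation above for a 2-D transpose with permutation [1, 0]. A
// result slice at offsets {2, 4} of shape 2x8 comes from the source region at
// offsets {4, 2} of shape 8x2. The concrete shapes are made up for the
// example.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> permutation = {1, 0};
  std::vector<int64_t> elementOffsets = {2, 4}, targetShape = {2, 8};

  std::vector<int64_t> permutedOffsets(2), permutedShape(2);
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutedOffsets[permutation[i]] = elementOffsets[i];
    permutedShape[permutation[i]] = targetShape[i];
  }
  assert((permutedOffsets == std::vector<int64_t>{4, 2}));
  assert((permutedShape == std::vector<int64_t>{8, 2}));
  return 0;
}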
// In UnrollGatherPattern::matchAndRewrite:
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    // ...
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    // ...
    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
                                          loopOrder);
    for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) {
      // ...
      // Slice the index, mask, and pass-through vectors consistently, then
      // gather on the smaller shape.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    // ...
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
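// Illustrative sketch (not part of this file): driving these patterns from a
// pass. RewritePatternSet and applyPatternsAndFoldGreedily are standard MLIR
// APIs as of the MLIR version this file corresponds to; the fixed 4x4 native
// shape is a made-up example policy.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult unrollVectorOps(mlir::Operation *root) {
  using namespace mlir;
  RewritePatternSet patterns(root->getContext());
  vector::UnrollVectorOptions options;
  // Ask every unrollable op to be decomposed into 4x4 pieces.
  options.setNativeShape({4, 4});
  vector::populateVectorUnrollPatterns(patterns, options);
  // Apply greedily until a fixed point; pieces still larger than the native
  // shape are matched again on the next iteration.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}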
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static SmallVector< int64_t > getVectorOffset(ArrayRef< int64_t > ratioStrides, int64_t index, ArrayRef< int64_t > targetShape)
During unrolling from originalShape to targetShape, return the offset for the slice index.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to generalize their behavior to vectors and tensors.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets for each dimension.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e. the product of its entries).
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
vector::UnrollVectorOptions
Options that control the vector unrolling.