25 #define DEBUG_TYPE "lower-contract-to-arm-neon"
33 static Type matchContainerType(
Type element,
Type container) {
34 if (
auto shapedTy = dyn_cast<ShapedType>(container)) {
35 return shapedTy.clone(element);
45 class LowerContractionToSMMLAPattern
49 LogicalResult matchAndRewrite(vector::ContractionOp op,
54 mlir::VectorType lhsType = op.getLhsType();
55 mlir::VectorType rhsType = op.getRhsType();
57 if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
61 if (lhsType.isScalable() || rhsType.isScalable())
63 auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
64 auto dimN = rhsType.getDimSize(0);
65 auto dimK = rhsType.getDimSize(1);
66 bool isVecmat = dimM == 1 ? true :
false;
67 if (lhsType.getDimSize(lhsType.getRank() - 1) !=
68 rhsType.getDimSize(rhsType.getRank() - 1)) {
73 if ((dimM % 2 != 0 && !
isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
79 auto iteratorTypes = op.getIteratorTypesArray();
80 if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
81 vector::IteratorType::reduction) {
85 [](vector::IteratorType iteratorType) {
86 return iteratorType != vector::IteratorType::parallel;
92 arith::ExtSIOp origLhsExtOp =
93 dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
94 arith::ExtSIOp origRhsExtOp =
95 dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
96 if (!origLhsExtOp || !origRhsExtOp) {
104 if (
auto lhsExtInType =
105 dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
106 if (lhsExtInType.getElementTypeBitWidth() <= 8) {
107 Type targetLhsExtTy =
108 matchContainerType(rewriter.
getI8Type(), lhsExtInType);
109 extsiLhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
110 origLhsExtOp.getIn());
113 if (
auto rhsExtInType =
114 dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
115 if (rhsExtInType.getElementTypeBitWidth() <= 8) {
116 Type targetRhsExtTy =
117 matchContainerType(rewriter.
getI8Type(), rhsExtInType);
118 extsiRhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
119 origRhsExtOp.getIn());
123 if (!extsiLhs || !extsiRhs) {
130 loc, op.getResultType(), rewriter.
getZeroAttr(op.getResultType()));
135 if (unrolledSize.size() == 3) {
136 smmlaShape.insert(smmlaShape.begin(),
isVecmat ? 1 : 2);
137 loopOrder.push_back(2);
145 auto extractOperand = [&](
Value operand,
AffineMap permutationMap,
150 return rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
151 loc, operand, operandOffsets, operandShape, operandStrides);
155 AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
158 Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
159 AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
162 Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
163 AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
167 extractOperand(op.getAcc(), accPermutationMap, accOffsets);
169 auto inputElementType =
170 cast<ShapedType>(tiledLhs.
getType()).getElementType();
171 auto accElementType =
172 cast<ShapedType>(tiledAcc.
getType()).getElementType();
179 auto expandForSMMLA = [&](
Value tiledOperand,
180 VectorType expandedTypeType) {
181 auto emptyOperand = rewriter.
create<arith::ConstantOp>(
182 loc, expandedTypeType, rewriter.
getZeroAttr(expandedTypeType));
184 cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
186 cast<ShapedType>(tiledOperand.
getType()).getRank(), 1);
187 return rewriter.
createOrFold<vector::InsertStridedSliceOp>(
188 loc, tiledOperand, emptyOperand, offsets, strides);
190 tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
191 tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
195 auto collapsedInputType =
197 auto collapsedLhs = rewriter.
createOrFold<vector::ShapeCastOp>(
198 tiledLhs.
getLoc(), collapsedInputType, tiledLhs);
199 auto collapsedRhs = rewriter.
createOrFold<vector::ShapeCastOp>(
200 tiledRhs.
getLoc(), collapsedInputType, tiledRhs);
201 auto collapsedOutputType =
204 bool initialKAcc = offsets.back() == 0;
209 collapsedRes = rewriter.
createOrFold<vector::ShapeCastOp>(
210 tiledAcc.
getLoc(), collapsedOutputType, tiledAcc);
215 op.getLoc(), collapsedRes.
getType(), collapsedRes, collapsedLhs,
224 tiledRes = rewriter.
createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
230 cast<ShapedType>(tiledRes.getType()).getRank(), 1);
231 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
232 loc, tiledRes, result, accOffsets, strides);
245 patterns.add<LowerContractionToSMMLAPattern>(context, 2);
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...