25 #define DEBUG_TYPE "lower-contract-to-arm-neon"
33 static Type matchContainerType(
Type element,
Type container) {
34 if (
auto shapedTy = dyn_cast<ShapedType>(container)) {
35 return shapedTy.clone(element);
45 class LowerContractionToSMMLAPattern
49 LogicalResult matchAndRewrite(vector::ContractionOp op,
54 mlir::VectorType lhsType = op.getLhsType();
55 mlir::VectorType rhsType = op.getRhsType();
57 if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
59 auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
60 auto dimN = rhsType.getDimSize(0);
61 auto dimK = rhsType.getDimSize(1);
62 bool isVecmat = dimM == 1 ? true :
false;
63 if (lhsType.getDimSize(lhsType.getRank() - 1) !=
64 rhsType.getDimSize(rhsType.getRank() - 1)) {
69 if ((dimM % 2 != 0 && !
isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
75 auto iteratorTypes = op.getIteratorTypesArray();
76 if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
77 vector::IteratorType::reduction) {
81 [](vector::IteratorType iteratorType) {
82 return iteratorType != vector::IteratorType::parallel;
88 arith::ExtSIOp origLhsExtOp =
89 dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
90 arith::ExtSIOp origRhsExtOp =
91 dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
92 if (!origLhsExtOp || !origRhsExtOp) {
100 if (
auto lhsExtInType =
101 dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
102 if (lhsExtInType.getElementTypeBitWidth() <= 8) {
103 Type targetLhsExtTy =
104 matchContainerType(rewriter.
getI8Type(), lhsExtInType);
105 extsiLhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
106 origLhsExtOp.getIn());
109 if (
auto rhsExtInType =
110 dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
111 if (rhsExtInType.getElementTypeBitWidth() <= 8) {
112 Type targetRhsExtTy =
113 matchContainerType(rewriter.
getI8Type(), rhsExtInType);
114 extsiRhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
115 origRhsExtOp.getIn());
119 if (!extsiLhs || !extsiRhs) {
126 loc, op.getResultType(), rewriter.
getZeroAttr(op.getResultType()));
131 if (unrolledSize.size() == 3) {
132 smmlaShape.insert(smmlaShape.begin(),
isVecmat ? 1 : 2);
133 loopOrder.push_back(2);
141 auto extractOperand = [&](
Value operand,
AffineMap permutationMap,
146 return rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
147 loc, operand, operandOffsets, operandShape, operandStrides);
151 AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
154 Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
155 AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
158 Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
159 AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
163 extractOperand(op.getAcc(), accPermutationMap, accOffsets);
165 auto inputElementType =
166 cast<ShapedType>(tiledLhs.
getType()).getElementType();
167 auto accElementType =
168 cast<ShapedType>(tiledAcc.
getType()).getElementType();
175 auto expandForSMMLA = [&](
Value tiledOperand,
176 VectorType expandedTypeType) {
177 auto emptyOperand = rewriter.
create<arith::ConstantOp>(
178 loc, expandedTypeType, rewriter.
getZeroAttr(expandedTypeType));
180 cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
182 cast<ShapedType>(tiledOperand.
getType()).getRank(), 1);
183 return rewriter.
createOrFold<vector::InsertStridedSliceOp>(
184 loc, tiledOperand, emptyOperand, offsets, strides);
186 tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
187 tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
191 auto collapsedInputType =
193 auto collapsedLhs = rewriter.
createOrFold<vector::ShapeCastOp>(
194 tiledLhs.
getLoc(), collapsedInputType, tiledLhs);
195 auto collapsedRhs = rewriter.
createOrFold<vector::ShapeCastOp>(
196 tiledRhs.
getLoc(), collapsedInputType, tiledRhs);
197 auto collapsedOutputType =
200 bool initialKAcc = offsets.back() == 0;
205 collapsedRes = rewriter.
createOrFold<vector::ShapeCastOp>(
206 tiledAcc.
getLoc(), collapsedOutputType, tiledAcc);
211 op.getLoc(), collapsedRes.
getType(), collapsedRes, collapsedLhs,
220 tiledRes = rewriter.
createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
226 cast<ShapedType>(tiledRes.getType()).getRank(), 1);
227 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
228 loc, tiledRes, result, accOffsets, strides);
241 patterns.
add<LowerContractionToSMMLAPattern>(context, 1);
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tileShape within the larger shape shape, using an ordering specified by loopOrder.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.