26 #define DEBUG_TYPE "lower-contract-to-arm-neon"
34 static Type matchContainerType(
Type element,
Type container) {
35 if (
auto shapedTy = dyn_cast<ShapedType>(container)) {
36 return shapedTy.clone(element);
46 class LowerContractionToSMMLAPattern
55 mlir::VectorType lhsType = op.getLhsType();
56 mlir::VectorType rhsType = op.getRhsType();
58 if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
60 auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
61 auto dimN = rhsType.getDimSize(0);
62 auto dimK = rhsType.getDimSize(1);
63 bool isVecmat = dimM == 1 ? true :
false;
64 if (lhsType.getDimSize(lhsType.getRank() - 1) !=
65 rhsType.getDimSize(rhsType.getRank() - 1)) {
70 if ((dimM % 2 != 0 && !
isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
76 auto iteratorTypes = op.getIteratorTypesArray();
77 if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
78 vector::IteratorType::reduction) {
82 [](vector::IteratorType iteratorType) {
83 return iteratorType != vector::IteratorType::parallel;
89 arith::ExtSIOp origLhsExtOp =
90 dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
91 arith::ExtSIOp origRhsExtOp =
92 dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
93 if (!origLhsExtOp || !origRhsExtOp) {
101 if (
auto lhsExtInType =
102 dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
103 if (lhsExtInType.getElementTypeBitWidth() <= 8) {
104 Type targetLhsExtTy =
105 matchContainerType(rewriter.
getI8Type(), lhsExtInType);
106 extsiLhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
107 origLhsExtOp.getIn());
110 if (
auto rhsExtInType =
111 dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
112 if (rhsExtInType.getElementTypeBitWidth() <= 8) {
113 Type targetRhsExtTy =
114 matchContainerType(rewriter.
getI8Type(), rhsExtInType);
115 extsiRhs = rewriter.
createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
116 origRhsExtOp.getIn());
120 if (!extsiLhs || !extsiRhs) {
127 loc, op.getResultType(), rewriter.
getZeroAttr(op.getResultType()));
132 if (unrolledSize.size() == 3) {
133 smmlaShape.insert(smmlaShape.begin(),
isVecmat ? 1 : 2);
134 loopOrder.push_back(2);
142 auto extractOperand = [&](
Value operand,
AffineMap permutationMap,
147 return rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
148 loc, operand, operandOffsets, operandShape, operandStrides);
152 AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
155 Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
156 AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
159 Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
160 AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
164 extractOperand(op.getAcc(), accPermutationMap, accOffsets);
166 auto inputElementType =
167 cast<ShapedType>(tiledLhs.
getType()).getElementType();
168 auto accElementType =
169 cast<ShapedType>(tiledAcc.
getType()).getElementType();
176 auto expandForSMMLA = [&](
Value tiledOperand,
177 VectorType expandedTypeType) {
178 auto emptyOperand = rewriter.
create<arith::ConstantOp>(
179 loc, expandedTypeType, rewriter.
getZeroAttr(expandedTypeType));
181 cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
183 cast<ShapedType>(tiledOperand.
getType()).getRank(), 1);
184 return rewriter.
createOrFold<vector::InsertStridedSliceOp>(
185 loc, tiledOperand, emptyOperand, offsets, strides);
187 tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
188 tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
192 auto collapsedInputType =
194 auto collapsedLhs = rewriter.
createOrFold<vector::ShapeCastOp>(
195 tiledLhs.
getLoc(), collapsedInputType, tiledLhs);
196 auto collapsedRhs = rewriter.
createOrFold<vector::ShapeCastOp>(
197 tiledRhs.
getLoc(), collapsedInputType, tiledRhs);
198 auto collapsedOutputType =
201 bool initialKAcc = offsets.back() == 0;
206 collapsedRes = rewriter.
createOrFold<vector::ShapeCastOp>(
207 tiledAcc.
getLoc(), collapsedOutputType, tiledAcc);
212 op.
getLoc(), collapsedRes.
getType(), collapsedRes, collapsedLhs,
221 tiledRes = rewriter.
createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
227 cast<ShapedType>(tiledRes.getType()).getRank(), 1);
228 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
229 loc, tiledRes, result, accOffsets, strides);
242 patterns.
add<LowerContractionToSMMLAPattern>(context, 1);
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...