MLIR  21.0.0git
LowerContractionToSMMLAPattern.cpp
Go to the documentation of this file.
1 //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering patterns from vector.contract to
10 // arm_neon.intr.smmla
11 //
//===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/PatternMatch.h"
24 
25 #define DEBUG_TYPE "lower-contract-to-arm-neon"
26 
27 using namespace mlir;
28 using namespace mlir::arm_neon;
29 
30 namespace {
31 
32 /// Return the shaped type with new element type.
33 static Type matchContainerType(Type element, Type container) {
34  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
35  return shapedTy.clone(element);
36  }
37  return element;
38 }
39 
40 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41 /// any vector.contract into multiple smmla instructions with unrolling so long
42 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
43 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
44 /// necessary, a single smmla instruction is emitted.
45 class LowerContractionToSMMLAPattern
46  : public OpRewritePattern<vector::ContractionOp> {
47 public:
49  LogicalResult matchAndRewrite(vector::ContractionOp op,
50  PatternRewriter &rewriter) const override {
51  Location loc = op.getLoc();
52  // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
53  // Note: RHS is not transposed.
54  mlir::VectorType lhsType = op.getLhsType();
55  mlir::VectorType rhsType = op.getRhsType();
56  // Avoid 0-D vectors and 1-D rhs:
57  if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
58  return failure();
59  // This codegen does not work for scalable vectors. Return failure so this
60  // pattern is not accidentally chosen over patterns that lower to ArmSVE.
61  if (lhsType.isScalable() || rhsType.isScalable())
62  return failure();
63  auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
64  auto dimN = rhsType.getDimSize(0);
65  auto dimK = rhsType.getDimSize(1);
66  bool isVecmat = dimM == 1 ? true : false;
67  if (lhsType.getDimSize(lhsType.getRank() - 1) !=
68  rhsType.getDimSize(rhsType.getRank() - 1)) {
69  return failure(); // dimK mismatch
70  }
71  // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
72  // tiling.
73  if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
74  return failure();
75  }
76 
77  // Check iterator types for contract. All iterators except inner-most
78  // dimension must be parallel.
79  auto iteratorTypes = op.getIteratorTypesArray();
80  if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
81  vector::IteratorType::reduction) {
82  return failure();
83  }
84  if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
85  [](vector::IteratorType iteratorType) {
86  return iteratorType != vector::IteratorType::parallel;
87  })) {
88  return failure();
89  }
90 
91  // Check two extsi inputs Rhs Lhs for contract.
92  arith::ExtSIOp origLhsExtOp =
93  dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
94  arith::ExtSIOp origRhsExtOp =
95  dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
96  if (!origLhsExtOp || !origRhsExtOp) {
97  return failure();
98  }
99 
100  // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
101  // following neon instruction. Check inputs for extsi are <=i8
102  Value extsiLhs;
103  Value extsiRhs;
104  if (auto lhsExtInType =
105  dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
106  if (lhsExtInType.getElementTypeBitWidth() <= 8) {
107  Type targetLhsExtTy =
108  matchContainerType(rewriter.getI8Type(), lhsExtInType);
109  extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
110  origLhsExtOp.getIn());
111  }
112  }
113  if (auto rhsExtInType =
114  dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
115  if (rhsExtInType.getElementTypeBitWidth() <= 8) {
116  Type targetRhsExtTy =
117  matchContainerType(rewriter.getI8Type(), rhsExtInType);
118  extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
119  origRhsExtOp.getIn());
120  }
121  }
122 
123  if (!extsiLhs || !extsiRhs) {
124  return failure();
125  }
126 
127  // Initial accumulator for the final result. This is the un-tiled result if
128  // tiling is done.
129  Value result = rewriter.create<arith::ConstantOp>(
130  loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
131 
132  SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
133  SmallVector<int64_t> smmlaShape = {2, 8};
134  SmallVector<int64_t> loopOrder = {0, 1};
135  if (unrolledSize.size() == 3) {
136  smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
137  loopOrder.push_back(2);
138  }
139 
140  // Keep track of the previous accumulator when tiling over K.
141  Value kAcc;
142  for (SmallVector<int64_t> offsets :
143  StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
144  // Helper to compute the new shape of each operand and extract the slice.
145  auto extractOperand = [&](Value operand, AffineMap permutationMap,
146  ArrayRef<int64_t> operandOffsets) {
147  SmallVector<int64_t> operandShape =
148  applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
149  SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
150  return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
151  loc, operand, operandOffsets, operandShape, operandStrides);
152  };
153 
154  // Extract tiled lhs, rhs, and acc
155  AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
156  SmallVector<int64_t> lhsOffsets =
157  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
158  Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
159  AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
160  SmallVector<int64_t> rhsOffsets =
161  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
162  Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
163  AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
164  SmallVector<int64_t> accOffsets =
165  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
166  Value tiledAcc =
167  extractOperand(op.getAcc(), accPermutationMap, accOffsets);
168 
169  auto inputElementType =
170  cast<ShapedType>(tiledLhs.getType()).getElementType();
171  auto accElementType =
172  cast<ShapedType>(tiledAcc.getType()).getElementType();
173  auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
174  auto outputExpandedType = VectorType::get({2, 2}, accElementType);
175 
176  // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
177  // rows along dimM. Expand their shapes to match the smmla op.
178  if (isVecmat) {
179  auto expandForSMMLA = [&](Value tiledOperand,
180  VectorType expandedTypeType) {
181  auto emptyOperand = rewriter.create<arith::ConstantOp>(
182  loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
183  SmallVector<int64_t> offsets(
184  cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
185  SmallVector<int64_t> strides(
186  cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
187  return rewriter.createOrFold<vector::InsertStridedSliceOp>(
188  loc, tiledOperand, emptyOperand, offsets, strides);
189  };
190  tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
191  tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
192  }
193 
194  // Collapse tiled operands to 1D vectors required by smmla intrinsic
195  auto collapsedInputType =
196  VectorType::get(inputExpandedType.getNumElements(), inputElementType);
197  auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
198  tiledLhs.getLoc(), collapsedInputType, tiledLhs);
199  auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
200  tiledRhs.getLoc(), collapsedInputType, tiledRhs);
201  auto collapsedOutputType =
202  VectorType::get(outputExpandedType.getNumElements(), accElementType);
203 
204  bool initialKAcc = offsets.back() == 0;
205  Value collapsedRes;
206  if (!initialKAcc) {
207  collapsedRes = kAcc;
208  } else {
209  collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
210  tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
211  }
212 
213  // Insert contract op
214  kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
215  op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
216  collapsedRhs);
217 
218  // Reshape output back to 2D
219  Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
220  kAcc.getLoc(), tiledAcc.getType(), kAcc);
221 
222  // With vecmat, only one row of tiled ACC can be inserted into file result
223  if (isVecmat) {
224  tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
225  }
226 
227  // Insert the tiled result back into the non tiled result of the
228  // contract op.
229  SmallVector<int64_t> strides(
230  cast<ShapedType>(tiledRes.getType()).getRank(), 1);
231  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
232  loc, tiledRes, result, accOffsets, strides);
233  }
234 
235  rewriter.replaceOp(op, result);
236  return success();
237  }
238 };
239 
240 } // namespace
241 
244  MLIRContext *context = patterns.getContext();
245  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
246 }
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI8Type()
Definition: Builders.cpp:59
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319