MLIR  20.0.0git
LowerContractionToSMMLAPattern.cpp
Go to the documentation of this file.
1 //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering patterns from vector.contract to
10 // arm_neon.intr.smmla
11 //
12 //===----------------------------------------------------------------------===//
13 
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/PatternMatch.h"
24 
25 #define DEBUG_TYPE "lower-contract-to-arm-neon"
26 
27 using namespace mlir;
28 using namespace mlir::arm_neon;
29 
30 namespace {
31 
32 /// Return the shaped type with new element type.
33 static Type matchContainerType(Type element, Type container) {
34  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
35  return shapedTy.clone(element);
36  }
37  return element;
38 }
39 
40 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
41 /// any vector.contract into multiple smmla instructions with unrolling so long
42 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
43 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
44 /// necessary, a single smmla instruction is emitted.
45 class LowerContractionToSMMLAPattern
46  : public OpRewritePattern<vector::ContractionOp> {
47 public:
49  LogicalResult matchAndRewrite(vector::ContractionOp op,
50  PatternRewriter &rewriter) const override {
51  Location loc = op.getLoc();
52  // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
53  // Note: RHS is not transposed.
54  mlir::VectorType lhsType = op.getLhsType();
55  mlir::VectorType rhsType = op.getRhsType();
56  // Avoid 0-D vectors and 1-D rhs:
57  if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
58  return failure();
59  auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
60  auto dimN = rhsType.getDimSize(0);
61  auto dimK = rhsType.getDimSize(1);
62  bool isVecmat = dimM == 1 ? true : false;
63  if (lhsType.getDimSize(lhsType.getRank() - 1) !=
64  rhsType.getDimSize(rhsType.getRank() - 1)) {
65  return failure(); // dimK mismatch
66  }
67  // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
68  // tiling.
69  if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
70  return failure();
71  }
72 
73  // Check iterator types for contract. All iterators except inner-most
74  // dimension must be parallel.
75  auto iteratorTypes = op.getIteratorTypesArray();
76  if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
77  vector::IteratorType::reduction) {
78  return failure();
79  }
80  if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
81  [](vector::IteratorType iteratorType) {
82  return iteratorType != vector::IteratorType::parallel;
83  })) {
84  return failure();
85  }
86 
87  // Check two extsi inputs Rhs Lhs for contract.
88  arith::ExtSIOp origLhsExtOp =
89  dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
90  arith::ExtSIOp origRhsExtOp =
91  dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
92  if (!origLhsExtOp || !origRhsExtOp) {
93  return failure();
94  }
95 
96  // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
97  // following neon instruction. Check inputs for extsi are <=i8
98  Value extsiLhs;
99  Value extsiRhs;
100  if (auto lhsExtInType =
101  dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
102  if (lhsExtInType.getElementTypeBitWidth() <= 8) {
103  Type targetLhsExtTy =
104  matchContainerType(rewriter.getI8Type(), lhsExtInType);
105  extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
106  origLhsExtOp.getIn());
107  }
108  }
109  if (auto rhsExtInType =
110  dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
111  if (rhsExtInType.getElementTypeBitWidth() <= 8) {
112  Type targetRhsExtTy =
113  matchContainerType(rewriter.getI8Type(), rhsExtInType);
114  extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
115  origRhsExtOp.getIn());
116  }
117  }
118 
119  if (!extsiLhs || !extsiRhs) {
120  return failure();
121  }
122 
123  // Initial accumulator for the final result. This is the un-tiled result if
124  // tiling is done.
125  Value result = rewriter.create<arith::ConstantOp>(
126  loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
127 
128  SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
129  SmallVector<int64_t> smmlaShape{2, 8};
130  SmallVector<int64_t> loopOrder{0, 1};
131  if (unrolledSize.size() == 3) {
132  smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
133  loopOrder.push_back(2);
134  }
135 
136  // Keep track of the previous accumulator when tiling over K.
137  Value kAcc;
138  for (SmallVector<int64_t> offsets :
139  StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
140  // Helper to compute the new shape of each operand and extract the slice.
141  auto extractOperand = [&](Value operand, AffineMap permutationMap,
142  ArrayRef<int64_t> operandOffsets) {
143  SmallVector<int64_t> operandShape =
144  applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
145  SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
146  return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
147  loc, operand, operandOffsets, operandShape, operandStrides);
148  };
149 
150  // Extract tiled lhs, rhs, and acc
151  AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
152  SmallVector<int64_t> lhsOffsets =
153  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
154  Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
155  AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
156  SmallVector<int64_t> rhsOffsets =
157  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
158  Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
159  AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
160  SmallVector<int64_t> accOffsets =
161  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
162  Value tiledAcc =
163  extractOperand(op.getAcc(), accPermutationMap, accOffsets);
164 
165  auto inputElementType =
166  cast<ShapedType>(tiledLhs.getType()).getElementType();
167  auto accElementType =
168  cast<ShapedType>(tiledAcc.getType()).getElementType();
169  auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
170  auto outputExpandedType = VectorType::get({2, 2}, accElementType);
171 
172  // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
173  // rows along dimM. Expand their shapes to match the smmla op.
174  if (isVecmat) {
175  auto expandForSMMLA = [&](Value tiledOperand,
176  VectorType expandedTypeType) {
177  auto emptyOperand = rewriter.create<arith::ConstantOp>(
178  loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
179  SmallVector<int64_t> offsets(
180  cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
181  SmallVector<int64_t> strides(
182  cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
183  return rewriter.createOrFold<vector::InsertStridedSliceOp>(
184  loc, tiledOperand, emptyOperand, offsets, strides);
185  };
186  tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
187  tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
188  }
189 
190  // Collapse tiled operands to 1D vectors required by smmla intrinsic
191  auto collapsedInputType =
192  VectorType::get(inputExpandedType.getNumElements(), inputElementType);
193  auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
194  tiledLhs.getLoc(), collapsedInputType, tiledLhs);
195  auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
196  tiledRhs.getLoc(), collapsedInputType, tiledRhs);
197  auto collapsedOutputType =
198  VectorType::get(outputExpandedType.getNumElements(), accElementType);
199 
200  bool initialKAcc = offsets.back() == 0;
201  Value collapsedRes;
202  if (!initialKAcc) {
203  collapsedRes = kAcc;
204  } else {
205  collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
206  tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
207  }
208 
209  // Insert contract op
210  kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
211  op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
212  collapsedRhs);
213 
214  // Reshape output back to 2D
215  Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
216  kAcc.getLoc(), tiledAcc.getType(), kAcc);
217 
218  // With vecmat, only one row of tiled ACC can be inserted into file result
219  if (isVecmat) {
220  tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
221  }
222 
223  // Insert the tiled result back into the non tiled result of the
224  // contract op.
225  SmallVector<int64_t> strides(
226  cast<ShapedType>(tiledRes.getType()).getRank(), 1);
227  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
228  loc, tiledRes, result, accOffsets, strides);
229  }
230 
231  rewriter.replaceOp(op, result);
232  return success();
233  }
234 };
235 
236 } // namespace
237 
// NOTE(review): the opening line of this definition (return type and function
// name) is absent from this listing; only the parameter list and the body are
// visible. Restore it from the original file before compiling.
// Adds LowerContractionToSMMLAPattern to `patterns` with benefit 1, using the
// MLIRContext owned by the pattern set.
239  RewritePatternSet &patterns) {
240  MLIRContext *context = patterns.getContext();
241  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
242 }
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
IntegerType getI8Type()
Definition: Builders.cpp:103
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362