MLIR  19.0.0git
LowerContractionToSMMLAPattern.cpp
Go to the documentation of this file.
1 //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering patterns from vector.contract to
10 // arm_neon.intr.smmla
11 //
12 //===---
13 
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/PatternMatch.h"
25 
26 #define DEBUG_TYPE "lower-contract-to-arm-neon"
27 
28 using namespace mlir;
29 using namespace mlir::arm_neon;
30 
31 namespace {
32 
33 /// Return the shaped type with new element type.
34 static Type matchContainerType(Type element, Type container) {
35  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
36  return shapedTy.clone(element);
37  }
38  return element;
39 }
40 
41 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
42 /// any vector.contract into multiple smmla instructions with unrolling so long
43 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
44 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
45 /// necessary, a single smmla instruction is emitted.
46 class LowerContractionToSMMLAPattern
47  : public OpRewritePattern<vector::ContractionOp> {
48 public:
50  LogicalResult matchAndRewrite(vector::ContractionOp op,
51  PatternRewriter &rewriter) const override {
52  Location loc = op.getLoc();
53  // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
54  // Note: RHS is not transposed.
55  mlir::VectorType lhsType = op.getLhsType();
56  mlir::VectorType rhsType = op.getRhsType();
57  // Avoid 0-D vectors and 1-D rhs:
58  if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
59  return failure();
60  auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
61  auto dimN = rhsType.getDimSize(0);
62  auto dimK = rhsType.getDimSize(1);
63  bool isVecmat = dimM == 1 ? true : false;
64  if (lhsType.getDimSize(lhsType.getRank() - 1) !=
65  rhsType.getDimSize(rhsType.getRank() - 1)) {
66  return failure(); // dimK mismatch
67  }
68  // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
69  // tiling.
70  if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
71  return failure();
72  }
73 
74  // Check iterator types for contract. All iterators except inner-most
75  // dimension must be parallel.
76  auto iteratorTypes = op.getIteratorTypesArray();
77  if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
78  vector::IteratorType::reduction) {
79  return failure();
80  }
81  if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
82  [](vector::IteratorType iteratorType) {
83  return iteratorType != vector::IteratorType::parallel;
84  })) {
85  return failure();
86  }
87 
88  // Check two extsi inputs Rhs Lhs for contract.
89  arith::ExtSIOp origLhsExtOp =
90  dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
91  arith::ExtSIOp origRhsExtOp =
92  dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
93  if (!origLhsExtOp || !origRhsExtOp) {
94  return failure();
95  }
96 
97  // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
98  // following neon instruction. Check inputs for extsi are <=i8
99  Value extsiLhs;
100  Value extsiRhs;
101  if (auto lhsExtInType =
102  dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
103  if (lhsExtInType.getElementTypeBitWidth() <= 8) {
104  Type targetLhsExtTy =
105  matchContainerType(rewriter.getI8Type(), lhsExtInType);
106  extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
107  origLhsExtOp.getIn());
108  }
109  }
110  if (auto rhsExtInType =
111  dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
112  if (rhsExtInType.getElementTypeBitWidth() <= 8) {
113  Type targetRhsExtTy =
114  matchContainerType(rewriter.getI8Type(), rhsExtInType);
115  extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
116  origRhsExtOp.getIn());
117  }
118  }
119 
120  if (!extsiLhs || !extsiRhs) {
121  return failure();
122  }
123 
124  // Initial accumulator for the final result. This is the un-tiled result if
125  // tiling is done.
126  Value result = rewriter.create<arith::ConstantOp>(
127  loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
128 
129  SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
130  SmallVector<int64_t> smmlaShape{2, 8};
131  SmallVector<int64_t> loopOrder{0, 1};
132  if (unrolledSize.size() == 3) {
133  smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
134  loopOrder.push_back(2);
135  }
136 
137  // Keep track of the previous accumulator when tiling over K.
138  Value kAcc;
139  for (SmallVector<int64_t> offsets :
140  StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
141  // Helper to compute the new shape of each operand and extract the slice.
142  auto extractOperand = [&](Value operand, AffineMap permutationMap,
143  ArrayRef<int64_t> operandOffsets) {
144  SmallVector<int64_t> operandShape =
145  applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
146  SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
147  return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
148  loc, operand, operandOffsets, operandShape, operandStrides);
149  };
150 
151  // Extract tiled lhs, rhs, and acc
152  AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
153  SmallVector<int64_t> lhsOffsets =
154  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
155  Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
156  AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
157  SmallVector<int64_t> rhsOffsets =
158  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
159  Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
160  AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
161  SmallVector<int64_t> accOffsets =
162  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
163  Value tiledAcc =
164  extractOperand(op.getAcc(), accPermutationMap, accOffsets);
165 
166  auto inputElementType =
167  cast<ShapedType>(tiledLhs.getType()).getElementType();
168  auto accElementType =
169  cast<ShapedType>(tiledAcc.getType()).getElementType();
170  auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
171  auto outputExpandedType = VectorType::get({2, 2}, accElementType);
172 
173  // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
174  // rows along dimM. Expand their shapes to match the smmla op.
175  if (isVecmat) {
176  auto expandForSMMLA = [&](Value tiledOperand,
177  VectorType expandedTypeType) {
178  auto emptyOperand = rewriter.create<arith::ConstantOp>(
179  loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
180  SmallVector<int64_t> offsets(
181  cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
182  SmallVector<int64_t> strides(
183  cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
184  return rewriter.createOrFold<vector::InsertStridedSliceOp>(
185  loc, tiledOperand, emptyOperand, offsets, strides);
186  };
187  tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
188  tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
189  }
190 
191  // Collapse tiled operands to 1D vectors required by smmla intrinsic
192  auto collapsedInputType =
193  VectorType::get(inputExpandedType.getNumElements(), inputElementType);
194  auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
195  tiledLhs.getLoc(), collapsedInputType, tiledLhs);
196  auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
197  tiledRhs.getLoc(), collapsedInputType, tiledRhs);
198  auto collapsedOutputType =
199  VectorType::get(outputExpandedType.getNumElements(), accElementType);
200 
201  bool initialKAcc = offsets.back() == 0;
202  Value collapsedRes;
203  if (!initialKAcc) {
204  collapsedRes = kAcc;
205  } else {
206  collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
207  tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
208  }
209 
210  // Insert contract op
211  kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
212  op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
213  collapsedRhs);
214 
215  // Reshape output back to 2D
216  Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
217  kAcc.getLoc(), tiledAcc.getType(), kAcc);
218 
219  // With vecmat, only one row of tiled ACC can be inserted into file result
220  if (isVecmat) {
221  tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
222  }
223 
224  // Insert the tiled result back into the non tiled result of the
225  // contract op.
226  SmallVector<int64_t> strides(
227  cast<ShapedType>(tiledRes.getType()).getRank(), 1);
228  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
229  loc, tiledRes, result, accOffsets, strides);
230  }
231 
232  rewriter.replaceOp(op, result);
233  return success();
234  }
235 };
236 
237 } // namespace
238 
240  RewritePatternSet &patterns) {
241  MLIRContext *context = patterns.getContext();
242  patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
243 }
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI8Type()
Definition: Builders.cpp:79
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
void populateLowerContractionToSMMLAPatternPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:650
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isVecmat(ArrayAttr indexingMaps)
Tests whether the given maps describe a vector matrix multiplication.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362