MLIR  21.0.0git
LowerContractionToSVEI8MMPattern.cpp
Go to the documentation of this file.
1 //===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering patterns from vector.contract to operations
10 // that map to instructions from the SVE FEAT_I8MM extension.
11 //
12 // TODO: There may be opportunities to unify this with a similar pattern
13 // for Neon. See:
14 // https://github.com/llvm/llvm-project/issues/145559
15 // LowerContractionToNeonI8MMPattern.cpp
16 //
17 //===----------------------------------------------------------------------===//
18 
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/ErrorHandling.h"
29 
31 
32 #define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
33 
34 using namespace mlir;
35 
36 namespace {
37 // Get the operand of a `vector.contract`. This function is intended to abstract
38 // away from the particular way a value is extended before feeding it into the
39 // `vector.contract` - via zero-extend or an explicit or implicit sign-extend
40 // (for implicit sign-extension see `vector.contract` documentation).
41 //
42 // The template parameter `Op` indicates the extension operation (explicit or
43 // implicit) for which we are checking.
44 //
45 // Return success only for extensions from `i8` to `i32`.
46 template <typename Op>
47 std::optional<Value> getExtOperand(Value v) {
48 
49  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
50  "Must be instantiated with either sign- or zero- extension op");
51 
52  // If the operand is not defined by an explicit extend operation of the
53  // accepted operation type allow for an implicit sign-extension.
54  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
55  if (!extOp) {
56  if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
57  auto vTy = cast<VectorType>(v.getType());
58  if (!vTy.getElementType().isSignlessInteger(8))
59  return {};
60  return v;
61  }
62  return {};
63  }
64 
65  // If the operand is defined by an explicit extend operation of the accepted
66  // operation type, check it's extended from `i8` to `i32`.
67  auto inOp = extOp.getIn();
68  auto inTy = dyn_cast<VectorType>(inOp.getType());
69  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
70  return {};
71 
72  auto outTy = dyn_cast<VectorType>(extOp.getType());
73  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
74  return {};
75 
76  return inOp;
77 }
78 
// Designate the operation (resp. instruction) used to do sub-tile matrix
// multiplications.
enum class MMLA {
  Signed,      // smmla
  Unsigned,    // ummla
  Mixed,       // usmmla
  MixedSwapped // usmmla with LHS and RHS swapped
};
87 
88 // Create the matrix mulitply and accumulate operation according to `op`.
89 Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
90  mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
91  switch (op) {
92  case MMLA::Signed:
93  return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
94  case MMLA::Unsigned:
95  return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
96  case MMLA::Mixed:
97  return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
98  case MMLA::MixedSwapped:
99  // The accumulator comes transposed and the result will be transposed
100  // later, so all we have to do here is swap the operands.
101  return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
102  }
103 }
104 
105 /// Lower a contraction operation that performs a matrix multiplication
106 /// of two 8-bit integer matrix tiles with logical dimensions <Mx8> and <8x[N]>
107 /// for the left-hand side and the right-hand side, respectively,
108 /// yielding a <Mx[N]> 32-bit integer result.
109 ///
110 /// The operands' shapes are such that the operands can be evenly split into
111 /// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
112 /// instructions. The intent is that M and N are chosen (by higher level
113 /// transforms) in such a way as to maximise register usage. The main use case
114 /// we envision as of now is MMT4D, thus the RHS operand is expected
115 /// pre-transposed.
116 ///
117 /// The matrix multiplication is performed by unrolling the usual tiled matrix
118 /// multiplication algorithm using sub-tiles with dimensions <2x8> for the LHS,
119 /// <8x[2]> for the RHS, and <2x[2]> for the result and the input accumulator.
120 ///
121 /// One way to illustrate the operation is as follows:
122 ///
123 /// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
124 /// +-----------------------------
125 /// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
126 /// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
127 /// ... | ... ... ... ...
128 /// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
129 ///
130 /// The RHS operand is unpacked into N/2 values, each representing a sequence of
131 /// VSCALE number of sub-tiles with dimensions <8x2>.
132 /// The LHS operand is initially unpacked into M/2 values, each representing a
133 /// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
134 /// VSCALE times.
135 /// Multiplying thus replicated LHS sub-tile by the corresponding RHS sub-tile
136 /// correctly computes an entire result sub-tile.
137 class LowerContractionToSVEI8MMPattern
138  : public OpRewritePattern<vector::ContractionOp> {
139 public:
141  LogicalResult matchAndRewrite(vector::ContractionOp op,
142  PatternRewriter &rewriter) const override {
143 
144  Location loc = op.getLoc();
145  mlir::VectorType lhsType = op.getLhsType();
146  mlir::VectorType rhsType = op.getRhsType();
147 
148  // Check the rank the types so we can safely examine their dimensions.
149  if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
150  return rewriter.notifyMatchFailure(op, "non-matching operand shape");
151 
152  auto M = lhsType.getDimSize(0);
153  auto N = rhsType.getDimSize(0);
154  auto K = rhsType.getDimSize(1);
155 
156  // Check the operands have the expected shape:
157  // * for LHS: fixed vector MxK
158  // * for RHS: scalable vector [N]xK
159  // * K == 8
160  // * M and N even and at least 2
161  if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
162  rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
163  M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
164  !rhsType.getScalableDims()[0])
165  return rewriter.notifyMatchFailure(op, "non-matching operand shape");
166 
167  // Check permutation maps. For now only accept
168  // lhs: (d0, d1, d2) -> (d0, d2)
169  // rhs: (d0, d1, d2) -> (d1, d2)
170  // acc: (d0, d1, d2) -> (d0, d1)
171  // This corresponds to matrix multiplication with transposed RHS.
172  if (op.getIndexingMapsArray()[0] !=
174  op.getContext()) ||
175  op.getIndexingMapsArray()[1] !=
177  op.getContext()) ||
178  op.getIndexingMapsArray()[2] !=
180  op.getContext()))
181  return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
182 
183  // Check iterator types for matrix multiplication.
184  auto itTypes = op.getIteratorTypesArray();
185  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
186  itTypes[1] != vector::IteratorType::parallel ||
187  itTypes[2] != vector::IteratorType::reduction)
188  return rewriter.notifyMatchFailure(
189  op, "iterator types do not correspond to matrix multiplication");
190 
191  // Check the combining kind is addition.
192  if (op.getKind() != vector::CombiningKind::ADD)
193  return rewriter.notifyMatchFailure(op,
194  "combining kind is not an addition");
195 
196  // Check the output is a vector of i32 elements.
197  auto outTy = dyn_cast<VectorType>(op.getResultType());
198  if (!outTy || outTy.getElementType() != rewriter.getI32Type())
199  return rewriter.notifyMatchFailure(op,
200  "output type is not a vector of i32");
201 
202  // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
203  // before the extension. All four signed/unsigned combinations for input
204  // operands are supported, but they are lowered to different operations.
205  // Determine which is the appropriate operation to lower to.
206  MMLA mmlaOp = MMLA::Signed;
207  auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
208  if (!maybeLhs) {
209  mmlaOp = MMLA::Unsigned;
210  maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
211  }
212  if (!maybeLhs)
213  return rewriter.notifyMatchFailure(
214  op, "LHS is not a sign- or zero- extended i8");
215 
216  auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
217  if (maybeRhs) {
218  if (mmlaOp == MMLA::Unsigned)
219  mmlaOp = MMLA::Mixed;
220  } else {
221  if (mmlaOp == MMLA::Signed)
222  mmlaOp = MMLA::MixedSwapped;
223  maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
224  }
225  if (!maybeRhs)
226  return rewriter.notifyMatchFailure(
227  op, "RHS is not a sign- or zero- extended i8");
228 
229  // One-dimensional vector types for arm_sve.*mmla
230  auto nxv16i8 = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
231  /*scalableDims=*/{true});
232  auto nxv4i32 = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
233  /*scalableDims=*/{true});
234 
235  // Extract LHS sub-tiles with logicall shape <2x8>.
236  SmallVector<Value> lhsTile;
237  for (int64_t i = 0; i < M; i += 2) {
238  // Extract two consecutive rows of the LHS tile.
239  auto r0 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
240  ArrayRef<int64_t>{i});
241  auto r1 = rewriter.create<vector::ExtractOp>(loc, *maybeLhs,
242  ArrayRef<int64_t>{i + 1});
243  // Concatenate to obtain a 16 x i8 flattened sub-tile.
244  auto t = rewriter.create<vector::ShuffleOp>(
245  loc, r0, r1,
246  llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
247  14, 15});
248  // Turn it into a scalable vector.
249  auto s = rewriter.create<vector::ScalableInsertOp>(
250  loc, t, rewriter.create<ub::PoisonOp>(loc, nxv16i8), 0);
251  // Replicate the sub-tile VSCALE times to fill the entire vector.
252  auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
253  lhsTile.push_back(r);
254  }
255 
256  // "Flatten" the RHS tile from <[N]x8> to <[8*N]>.
257  auto rhs = rewriter.create<vector::ShapeCastOp>(
258  maybeRhs->getLoc(),
259  VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
260  /*scalableDims=*/{true}),
261  *maybeRhs);
262 
263  // Extract the RHS sub-tiles with logical shape <8x[2]>.
264  SmallVector<Value> rhsTile;
265  for (int64_t j = 0; j < N; j += 2)
266  rhsTile.push_back(
267  rewriter.create<vector::ScalableExtractOp>(loc, nxv16i8, rhs, j * 8));
268 
269  // Handy types for packing/unpacking of the accumulator tile.
270  auto accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
271  /*scalableDims=*/{true});
272  auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
273  /*scalableDims=*/{true});
274  auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
275  /*scalableDims=*/{true});
276  auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
277  /*scalableDims=*/{true});
278 
279  // Extract and pack the ACC sub-tiles.
280  SmallVector<Value> accTile;
281  for (int64_t i = 0; i < M; i += 2) {
282  // Extract two consecutive rows of the accumulator tile.
283  auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
284  ArrayRef<int64_t>{i});
285  auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
286  ArrayRef<int64_t>{i + 1});
287  Value accTileVec;
288  if (mmlaOp == MMLA::MixedSwapped) {
289  // We need to swap the positions of the LHS and RHS (since we don't have
290  // a signed * unsigned operation), but then each individual 2x2 tile of
291  // the acumulator and (later) the result need to be transposed.
292  accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
293  } else {
294  // Bitcast them to 64-bit elements, so subsequent
295  // interleave/deinterleave work on pairs of 32-bit numbers.
296  auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0);
297  auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1);
298 
299  // Interleave the rows, effectively flattening each 2x2 tile into 4
300  // consecutive elements.
301  auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64);
302 
303  // Bitcast back to 32-bit elements.
304  accTileVec =
305  rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
306  }
307  // Extract ACC sub-tiles.
308  for (int64_t j = 0; j < N; j += 2)
309  accTile.push_back(rewriter.create<vector::ScalableExtractOp>(
310  loc, nxv4i32, accTileVec, j * 2));
311  }
312 
313  // Emit sub-tile matrix multiplications.
314  SmallVector<Value> outTile;
315  for (int64_t i = 0; i < M / 2; ++i)
316  for (int64_t j = 0; j < N / 2; ++j) {
317  Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
318  accTile[i * N / 2 + j], lhsTile[i], rhsTile[j]);
319  outTile.push_back(mmla);
320  }
321 
322  // Unpack the OUT sub-tiles and insert into the result.
323  Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType());
324  for (int64_t i = 0; i < M / 2; ++i) {
325  // Collect a number of sub-tiles in a row.
326  Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty);
327  for (int64_t j = 0; j < N / 2; ++j)
328  row = rewriter.create<vector::ScalableInsertOp>(
329  loc, outTile[i * N / 2 + j], row, j * 4);
330 
331  // Unpack the row to obtain two rows of the output. If we have the out
332  // sub-tiles transposed we obtain two consecutive output rows by
333  // separating even and odd elements, i.e. a simple deinterleave.
334  // Otherwise, the interleave is by pairs.
335  Value out0, out1;
336  if (mmlaOp == MMLA::MixedSwapped) {
337  auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row);
338  out0 = tmp.getRes1();
339  out1 = tmp.getRes2();
340  } else {
341  // Deinterleave by pairs.
342  auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row);
343  auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64);
344 
345  // Bitcast back into 32-bit elements and insert into the result.
346  out0 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
347  deintr64.getRes1());
348  out1 = rewriter.create<vector::BitCastOp>(loc, accRowTy,
349  deintr64.getRes2());
350  }
351  result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2);
352  result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
353  }
354 
355  rewriter.replaceOp(op, result);
356  return success();
357  }
358 };
359 
360 } // namespace
361 
364  MLIRContext *context = patterns.getContext();
365  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
366 }
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
Definition: AffineMap.cpp:276
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getI8Type()
Definition: Builders.cpp:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateLowerContractionToSVEI8MMPatternPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.