//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the Neon FEAT_I8MM extension.
//
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
//   https://github.com/llvm/llvm-project/issues/145559
//   LowerContractToSVEPatterns.cpp
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {
/// Get the operand of a `vector.contract`. This function abstracts away the
/// particular way a value is extended before feeding it into the
/// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
/// (for implicit sign-extension see the `vector.contract` documentation).
///
/// The template parameter `Op` indicates the extension operation (explicit or
/// implicit) to check for.
///
/// Return success only for extensions from `iN` (N <= 8) to `i32`.
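///
/// As an illustrative sketch (not taken from this file's tests), the explicit
/// forms accepted for an i8 operand look like
///
///   %0 = arith.extsi %a : vector<4x8xi8> to vector<4x8xi32>  // Op == arith::ExtSIOp
///   %1 = arith.extui %b : vector<4x8xi8> to vector<4x8xi32>  // Op == arith::ExtUIOp
///
/// while an i8 (or narrower) value fed directly into the `vector.contract`,
/// with no extension op, is treated as implicitly sign-extended.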
template <typename Op>
std::optional<Value> getExtOperand(Value v) {

  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                "Must be instantiated with either sign- or zero- extension op");

  // If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
  auto extOp = v.getDefiningOp<Op>();
  if (!extOp) {
    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
      auto eltTy = cast<VectorType>(v.getType()).getElementType();
      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
        return {};
      return v;
    }
    return {};
  }

  // If the operand is defined by an explicit extend operation of the accepted
  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy)
    return {};
  auto inEltTy = inTy.getElementType();
  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
    return {};

  return inOp;
}

/// Helper function to extend a vector with elements iN, N < 8, to a vector of
/// i8. Do sign extension if the parameter `signExt` is true, zero extension
/// otherwise.
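///
/// For example (illustrative only), an i4 operand of type vector<2x8xi4> is
/// widened here to vector<2x8xi8> before being passed to the MMLA ops below.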
Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
                           bool signExt, PatternRewriter &rewriter) {
  Type targetTy = srcTy.clone(rewriter.getI8Type());
  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
}

class VectorContractRewriter {
protected:
  // Designate the operation (resp. instruction) used to do sub-tile matrix
  // multiplications.
  enum class MMLA {
    Nop,
    SignedInt,   // smmla
    UnsignedInt, // ummla
    MixedInt,    // usmmla
    Bfloat       // bfmmla
  };
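
  // For reference: each of these instructions multiplies a 2x8 by an 8x2
  // sub-tile of 8-bit integers (2x4 by 4x2 for bf16), accumulating into a 2x2
  // sub-tile of i32 (resp. f32). Operands and accumulator are passed as
  // flattened 1-D vectors, hence the shape casts in `lower` below.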

  // Lower-level operation to be emitted.
  MMLA mmlaOp = MMLA::Nop;

  // Indicate if the operands for the ArmNeon dialect operation need to be
  // swapped. Currently this is needed in order to emulate a "summla"
  // (signed-by-unsigned) operation, for which there is no Neon instruction.
  bool swapOperands = false;

  // The operand tiles. These are not necessarily the operands of the
  // `vector.contract`; for example, they could be the operands of an
  // `arith.extsi` that is in turn fed into the `vector.contract`.
  Value lhs;
  Value rhs;
  Value acc;

  // The dimensions logically corresponding to a matrix multiplication of
  // MxK * KxN -> MxN. The operands and the result do not necessarily have
  // these shapes; for example, RHS could be NxK with a transposing indexing
  // map.
  int64_t dimM = 0;
  int64_t dimN = 0;
  int64_t dimK = 0;
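
  // Illustrative sketch (not tied to any particular test): with an NxK RHS,
  // such a contraction typically carries the indexing maps
  //   LHS: (m, n, k) -> (m, k)
  //   RHS: (m, n, k) -> (n, k)
  //   ACC: (m, n, k) -> (m, n)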

  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
  SmallVector<int64_t> iterationBounds;

  // Sub-tile shape. The algorithm handles operand shapes that are multiples
  // of this shape.
  SmallVector<int64_t> subTileShape;

  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                   Value lhs, Value rhs) {

    if (swapOperands)
      std::swap(lhs, rhs);
    switch (mmlaOp) {
    case MMLA::SignedInt:
      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
                                                      lhs, rhs);
    case MMLA::UnsignedInt:
      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
                                                      lhs, rhs);
    case MMLA::MixedInt:
      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
    case MMLA::Bfloat:
      return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs,
                                        rhs);
    case MMLA::Nop:
      llvm_unreachable("Uninitialized operation type");
    }
  }

  // Check common preconditions for applying the patterns and initialize
  // logical dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    // Check iterator types for matrix multiplication.
    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
    if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
         itTypes[1] != vector::IteratorType::parallel ||
         itTypes[2] != vector::IteratorType::reduction) &&
        (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
         itTypes[1] != vector::IteratorType::reduction))
      return rewriter.notifyMatchFailure(
          op, "iterator types do not correspond to matrix multiplication");

    // Avoid 0-D vectors and 1-D rhs:
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
        rhsType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Invalid operand rank");

    // This codegen does not work for scalable vectors. Return failure so this
    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
    if (lhsType.isScalable() || rhsType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "Not applicable to scalable vectors");

    // Initialize dimensions and check for a matching K dimension.
    dimM = lhsType.getDimSize(0);
    dimN = rhsType.getDimSize(0);
    dimK = rhsType.getDimSize(1);

    int64_t lhsDimK;
    if (lhsType.getRank() == 1) {
      dimM = 1;
      lhsDimK = lhsType.getDimSize(0);
    } else {
      lhsDimK = lhsType.getDimSize(1);
    }

    if (lhsDimK != dimK)
      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");

    return success();
  }

public:
  void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
    // Create some convenience types.
    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
    auto inputExpandedType =
        VectorType::get({2, subTileShape.back()}, inputElementType);
    auto outputExpandedType = VectorType::get({2, 2}, accElementType);

    // One-dimensional representation of logical sub-tiles as required by the
    // ArmNeon ops.
    auto collapsedInputType =
        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
    auto collapsedOutputType =
        VectorType::get(outputExpandedType.getNumElements(), accElementType);

    // Get indexing maps for a more concise/convenient access.
    auto indexingMaps = op.getIndexingMapsArray();
    AffineMap &lhsPermutationMap = indexingMaps[0];
    AffineMap &rhsPermutationMap = indexingMaps[1];
    AffineMap &accPermutationMap = indexingMaps[2];

    Location loc = op.getLoc();

    // Initial accumulator for the final result. This is the un-tiled result if
    // tiling is done.
    Value result =
        arith::ConstantOp::create(rewriter, loc, op.getResultType(),
                                  rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t, 3> loopOrder = {0, 1};
    if (iterationBounds.size() == 3)
      loopOrder.push_back(2);

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(subTileShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc.
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);

      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
      // rows along dimM. Expand their shapes to match the ArmNeon op.
      if (dimM == 1) {
        auto expandRowVector = [&](Value tiledOperand,
                                   VectorType expandedTypeType) {
          auto emptyOperand =
              arith::ConstantOp::create(rewriter, loc, expandedTypeType,
                                        rewriter.getZeroAttr(expandedTypeType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
      }

      // Transpose ACC if doing signed by unsigned multiplication, because
      // we're using the instruction for unsigned by signed multiplication with
      // reversed operands.
      if (swapOperands)
        tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
                                               ArrayRef<int64_t>({1, 0}));

      // Collapse tiled operands to 1D vectors required by the ArmNeon ops.
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);

      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      // Emit the matrix multiply-accumulate for this sub-tile.
      kAcc =
          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);

      // Reshape output back to 2D.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // Because of the reversed operands the result is obtained transposed.
      // Transpose it back.
      if (swapOperands)
        tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
                                               ArrayRef<int64_t>({1, 0}));

      // With vecmat, only one row of tiled ACC can be inserted into the final
      // result.
      if (dimM == 1)
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);

      // Insert the tiled result back into the non-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
  }
};

class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
      return failure();

    // The unrolling in `lower` can handle any operand shapes that are
    // multiples of the [2, 2, 8] (M, N, K) sub-tile.
    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");

    // Check that the inputs are sign-/zero-extensions from iN (N <= 8) to i32
    // and get the values before the extension. All four signed/unsigned
    // combinations of input operands are supported, but they are lowered to
    // different operations. Determine the appropriate operation to lower to.
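    //
    // For reference, the (LHS, RHS) signedness combinations map as follows:
    //   (signed,   signed)   -> smmla
    //   (unsigned, unsigned) -> ummla
    //   (unsigned, signed)   -> usmmla
    //   (signed,   unsigned) -> usmmla with swapped operands and a transposed
    //                           accumulator/result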
    mmlaOp = MMLA::SignedInt;
    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
    if (!maybeLhs) {
      mmlaOp = MMLA::UnsignedInt;
      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
    }
    if (!maybeLhs)
      return rewriter.notifyMatchFailure(
          op, "LHS is not a sign- or zero- extended iN, N <= 8");

    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
    if (maybeRhs) {
      if (mmlaOp == MMLA::UnsignedInt)
        mmlaOp = MMLA::MixedInt;
    } else {
      if (mmlaOp == MMLA::SignedInt) {
        mmlaOp = MMLA::MixedInt;
        swapOperands = true;
      }
      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
    }

    if (!maybeRhs)
      return rewriter.notifyMatchFailure(
          op, "RHS is not a sign- or zero- extended iN, N <= 8");

    lhs = *maybeLhs;
    rhs = *maybeRhs;
    acc = op.getAcc();

    // Extend inputs from iN, N < 8, to i8.
    Location loc = op.getLoc();
    auto lhsExtInType = cast<VectorType>(lhs.getType());
    if (lhsExtInType.getElementTypeBitWidth() < 8)
      lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
                                 /* signExt */
                                 (mmlaOp == MMLA::SignedInt ||
                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
                                 rewriter);

    auto rhsExtInType = cast<VectorType>(rhs.getType());
    if (rhsExtInType.getElementTypeBitWidth() < 8)
      rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
                                 /* signExt */
                                 (mmlaOp == MMLA::SignedInt ||
                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
                                 rewriter);

    // Initialize parameters for unrolling.
    iterationBounds = *op.getShapeForUnroll();
    if (iterationBounds.size() == 3)
      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
    else
      subTileShape = SmallVector<int64_t>({2, 8});

    return success();
  }
};

class VectorContractRewriterBFMMLA : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {

    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
      return failure();

    // The unrolling in `lower` can handle any operand shapes that are
    // multiples of the [2, 2, 4] (M, N, K) sub-tile.
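    //
    // Illustrative sketch (not taken from the tests): a contraction like
    //
    //   %res = vector.contract {...} %lhs, %rhs, %acc
    //     : vector<4x8xbf16>, vector<6x8xbf16> into vector<4x6xf32>
    //
    // unrolls into 2x3x2 = 12 bfmmla operations over [2, 2, 4] sub-tiles.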
    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");

    // Check the output is a vector of Float32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of f32");

    // Check the inputs are vectors of BFloat16 elements.
    if (op.getLhsType().getElementType() != rewriter.getBF16Type())
      return rewriter.notifyMatchFailure(op,
                                         "input type is not a vector of bf16");

    mmlaOp = MMLA::Bfloat;
    swapOperands = false;
    lhs = op.getLhs();
    rhs = op.getRhs();
    acc = op.getAcc();

    // Initialize parameters for unrolling.
    iterationBounds = *op.getShapeForUnroll();
    if (iterationBounds.size() == 3)
      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
    else
      subTileShape = SmallVector<int64_t>({2, 4});

    return success();
  }
};

/// Lowering from a vector::contractOp to ArmNeon smmla intrinsics. This will
/// tile any vector.contract into multiple smmla instructions with unrolling,
/// so long as [2,2,8] is a divisor of its shape. It can also process vecmats
/// with dimM = 1 (either explicit, or inferred if LHS has only dimK). If no
/// unrolling is necessary, a single smmla instruction is emitted.
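///
/// As an illustrative sketch (not taken from the tests), a contraction such as
///
///   %lhsExt = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
///   %rhsExt = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
///   %res = vector.contract {...} %lhsExt, %rhsExt, %acc
///     : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
///
/// is rewritten into 2x2x1 = 4 smmla operations, each covering one [2, 2, 8]
/// (M, N, K) sub-tile.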
class LowerContractionToNeonI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    VectorContractRewriterI8MM vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    vcr.lower(op, rewriter);

    return success();
  }
};

class LowerContractionToNeonBFMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    VectorContractRewriterBFMMLA vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    vcr.lower(op, rewriter);

    return success();
  }
};

} // namespace

void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
}

void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
}
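
// For illustration only (not part of the upstream file): a conversion pass
// could apply these patterns with the greedy driver roughly as follows,
// assuming "mlir/Transforms/GreedyPatternRewriteDriver.h" is included:
//
//   RewritePatternSet patterns(ctx);
//   populateLowerContractionToNeonI8MMPatterns(patterns);
//   populateLowerContractionToNeonBFMMLAPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();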