//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the Neon FEAT_I8MM extension.
//
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
//   https://github.com/llvm/llvm-project/issues/145559
//   LowerContractToSVEPatterns.cpp
//
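// As a rough illustration (shapes only, not exact IR), a contraction that
// computes C[2x2, i32] += A[2x8, i8] * B[8x2, i8], with A and B extended to
// i32 before the vector.contract, maps to a single smmla operation on
// flattened operands: two vector<16xi8> inputs and a vector<4xi32>
// accumulator. Larger shapes that are multiples of the [2, 2, 8] (i8) or
// [2, 2, 4] (bf16) sub-tile are unrolled into several such operations.
//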
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"

#define DEBUG_TYPE "lower-contract-to-arm-neon"

using namespace mlir;
using namespace mlir::arm_neon;

namespace {
/// Get the operand of a `vector.contract`. This function is intended to
/// abstract away the particular way a value is extended before feeding it
/// into the `vector.contract` - via zero-extend or an explicit or implicit
/// sign-extend (for implicit sign-extension see `vector.contract`
/// documentation).
///
/// The template parameter `Op` indicates the extension operation (explicit or
/// implicit) for which we are checking.
///
/// Return success only for extensions from `iN` (N <= 8) to `i32`.
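///
/// For example, given
///   %e = arith.extsi %v : vector<2x8xi4> to vector<2x8xi32>
/// `getExtOperand<arith::ExtSIOp>` applied to `%e` returns `%v`, whereas a
/// plain `vector<2x8xi8>` value fed directly into the contraction is returned
/// unchanged, relying on the implicit sign-extension performed by
/// `vector.contract`.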
template <typename Op>
std::optional<Value> getExtOperand(Value v) {

  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                "Must be instantiated with either sign- or zero-extension op");

  // If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
  auto extOp = v.getDefiningOp<Op>();
  if (!extOp) {
    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
      auto eltTy = cast<VectorType>(v.getType()).getElementType();
      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
        return {};
      return v;
    }
    return {};
  }

  // If the operand is defined by an explicit extend operation of the accepted
  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy)
    return {};
  auto inEltTy = inTy.getElementType();
  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
    return {};

  return inOp;
}

/// Helper function to extend a vector with elements iN (N < 8) to
/// a vector of i8. Do sign extension if the parameter `signExt` is true,
/// zero extension otherwise.
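/// For example, a `vector<2x8xi4>` value becomes a `vector<2x8xi8>` value, so
/// that it can be consumed directly by the 8-bit integer MMLA operations.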
Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
                           bool signExt, PatternRewriter &rewriter) {
  Type targetTy = srcTy.clone(rewriter.getI8Type());
  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
}

class VectorContractRewriter {
protected:
  // Designate the operation (resp. instruction) used to do sub-tile matrix
  // multiplications.
  enum class MMLA {
    Nop,
    SignedInt,   // smmla
    UnsignedInt, // ummla
    MixedInt,    // usmmla
    Bfloat       // bfmmla
  };

  // Lower-level operation to be emitted.
  MMLA mmlaOp = MMLA::Nop;

  // Indicate if the operands for the ArmNeon dialect operation need to be
  // swapped. Currently this is needed in order to emulate a "summla"
  // operation.
  bool swapOperands = false;

  // The operand tiles. These are not necessarily the operands of
  // `vector.contract`, for example they could be operands to `arith.extsi`
  // that is in turn fed into `vector.contract`.
  Value lhs;
  Value rhs;
  Value acc;

  // The dimensions logically corresponding to matrix multiplication of
  // MxK * KxN -> MxN. The operands and the result do not necessarily have
  // these shapes, for example RHS could be NxK with a transposing indexing
  // map.
  int64_t dimM = 0;
  int64_t dimN = 0;
  int64_t dimK = 0;

  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
  SmallVector<int64_t> iterationBounds;

  // Sub-tile shape. The algorithm handles operand shapes that are multiples
  // of this shape.
  SmallVector<int64_t> subTileShape;
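  // For the I8MM lowering the K component of the sub-tile is 8 and for the
  // BF16 lowering it is 4; M and N are 2 (M may degenerate to 1 for vecmat).
  // One sub-tile corresponds to a single MMLA operation.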

  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                   Value lhs, Value rhs) {

    if (swapOperands)
      std::swap(lhs, rhs);
    switch (mmlaOp) {
    case MMLA::SignedInt:
      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
                                                      lhs, rhs);
    case MMLA::UnsignedInt:
      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
                                                      lhs, rhs);
    case MMLA::MixedInt:
      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
    case MMLA::Bfloat:
      return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs,
                                        rhs);
    case MMLA::Nop:
      llvm_unreachable("Uninitialized operation type");
    }
    llvm_unreachable("Unknown MMLA");
  }

  // Check common preconditions for applying the patterns and initialize
  // logical dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    // Check iterator types for matrix multiplication.
    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
    if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
         itTypes[1] != vector::IteratorType::parallel ||
         itTypes[2] != vector::IteratorType::reduction) &&
        (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
         itTypes[1] != vector::IteratorType::reduction))
      return rewriter.notifyMatchFailure(
          op, "iterator types do not correspond to matrix multiplication");

    // Avoid 0-D vectors and 1-D rhs:
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
        rhsType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Invalid operand rank");

    // This codegen does not work for scalable vectors. Return failure so this
    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
    if (lhsType.isScalable() || rhsType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "Not applicable to scalable vectors");

    // Initialize dimensions and check for a matching K dimension.
    dimM = lhsType.getDimSize(0);
    dimN = rhsType.getDimSize(0);
    dimK = rhsType.getDimSize(1);

    int64_t lhsDimK;
    if (lhsType.getRank() == 1) {
      dimM = 1;
      lhsDimK = lhsType.getDimSize(0);
    } else {
      lhsDimK = lhsType.getDimSize(1);
    }

    if (lhsDimK != dimK)
      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");

    return success();
  }

public:
  void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
    // Create some convenience types.
    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
    auto inputExpandedType =
        VectorType::get({2, subTileShape.back()}, inputElementType);
    auto outputExpandedType = VectorType::get({2, 2}, accElementType);

    // One-dimensional representation of logical sub-tiles as required by the
    // ArmNeon ops.
    auto collapsedInputType =
        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
    auto collapsedOutputType =
        VectorType::get(outputExpandedType.getNumElements(), accElementType);

    // Get indexing maps for more concise/convenient access.
    auto indexingMaps = op.getIndexingMapsArray();
    AffineMap &lhsPermutationMap = indexingMaps[0];
    AffineMap &rhsPermutationMap = indexingMaps[1];
    AffineMap &accPermutationMap = indexingMaps[2];

    Location loc = op.getLoc();

    // Initial accumulator for the final result. This is the un-tiled result if
    // tiling is done.
    Value result =
        arith::ConstantOp::create(rewriter, loc, op.getResultType(),
                                  rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t, 3> loopOrder = {0, 1};
    if (iterationBounds.size() == 3)
      loopOrder.push_back(2);

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
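    // As an illustration, with iterationBounds = [4, 4, 16] and subTileShape
    // = [2, 2, 8] the sub-tile offsets are visited with K innermost:
    //   (0,0,0), (0,0,8), (0,2,0), (0,2,8), (2,0,0), ...
    // so that consecutive iterations along K can reuse `kAcc` as the running
    // accumulator.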
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(subTileShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract tiled lhs, rhs, and acc.
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);

      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
      // rows along dimM. Expand their shapes to match the ArmNeon op.
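      // For example, a vector<8xi8> LHS row is inserted into a zeroed
      // vector<2x8xi8>, and a vector<2xi32> ACC row into a zeroed
      // vector<2x2xi32>, before being handed to the MMLA operation.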
      if (dimM == 1) {
        auto expandRowVector = [&](Value tiledOperand,
                                   VectorType expandedTypeType) {
          auto emptyOperand =
              arith::ConstantOp::create(rewriter, loc, expandedTypeType,
                                        rewriter.getZeroAttr(expandedTypeType));
          SmallVector<int64_t> offsets(
              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
          SmallVector<int64_t> strides(
              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
              loc, tiledOperand, emptyOperand, offsets, strides);
        };
        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
      }

      // Transpose ACC if doing signed by unsigned multiplication, because
      // we're using the instruction for unsigned by signed multiplication with
      // reversed operands.
      if (swapOperands)
        tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
                                               ArrayRef<int64_t>({1, 0}));

      // Collapse tiled operands to 1D vectors required by the ArmNeon ops.
      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);

      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (!initialKAcc) {
        collapsedRes = kAcc;
      } else {
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      }

      // Emit the multiply-accumulate operation for this sub-tile.
      kAcc =
          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);

      // Reshape output back to 2D.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), tiledAcc.getType(), kAcc);

      // Because of the reversed operands the result is obtained transposed.
      // Transpose it back.
      if (swapOperands)
        tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
                                               ArrayRef<int64_t>({1, 0}));

      // With vecmat, only one row of tiled ACC can be inserted into the final
      // result.
      if (dimM == 1)
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);

      // Insert the tiled result back into the non-tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }

    rewriter.replaceOp(op, result);
  }
};

class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
      return failure();

    // The unrolling below handles operand shapes that are multiples of the
    // [2, 2, 8] sub-tile shape (with M == 1 also allowed for vecmat).
    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");

    // Check inputs are sign-/zero-extensions from iN (N <= 8) to i32. Get the
    // values before the extension. All four signed/unsigned combinations for
    // input operands are supported, but they are lowered to different
    // operations. Determine which is the appropriate operation to lower to.
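    //
    //   LHS extension | RHS extension | operation
    //   --------------+---------------+--------------------------------
    //   signed        | signed        | smmla
    //   unsigned      | unsigned      | ummla
    //   unsigned      | signed        | usmmla
    //   signed        | unsigned      | usmmla with LHS and RHS swapped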
    mmlaOp = MMLA::SignedInt;
    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
    if (!maybeLhs) {
      mmlaOp = MMLA::UnsignedInt;
      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
    }
    if (!maybeLhs)
      return rewriter.notifyMatchFailure(
          op, "LHS is not a sign- or zero-extended iN, N <= 8");

    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
    if (maybeRhs) {
      if (mmlaOp == MMLA::UnsignedInt)
        mmlaOp = MMLA::MixedInt;
    } else {
      if (mmlaOp == MMLA::SignedInt) {
        mmlaOp = MMLA::MixedInt;
        swapOperands = true;
      }
      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
    }

    if (!maybeRhs)
      return rewriter.notifyMatchFailure(
          op, "RHS is not a sign- or zero-extended iN, N <= 8");

    lhs = *maybeLhs;
    rhs = *maybeRhs;
    acc = op.getAcc();

    // Extend inputs from iN (N < 8) to i8.
    Location loc = op.getLoc();
    auto lhsExtInType = cast<VectorType>(lhs.getType());
    if (lhsExtInType.getElementTypeBitWidth() < 8)
      lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
                                 /* signExt */
                                 (mmlaOp == MMLA::SignedInt ||
                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
                                 rewriter);

    auto rhsExtInType = cast<VectorType>(rhs.getType());
    if (rhsExtInType.getElementTypeBitWidth() < 8)
      rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
                                 /* signExt */
                                 (mmlaOp == MMLA::SignedInt ||
                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
                                 rewriter);

    // Initialize parameters for unrolling.
    iterationBounds = *op.getShapeForUnroll();
    if (iterationBounds.size() == 3)
      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
    else
      subTileShape = SmallVector<int64_t>({2, 8});

    return success();
  }
};

class VectorContractRewriterBFMMLA : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {

    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
      return failure();

    // The unrolling below handles operand shapes that are multiples of the
    // [2, 2, 4] sub-tile shape (with M == 1 also allowed for vecmat).
    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");

    // Check the output is a vector of Float32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of f32");

    // Check the inputs are vectors of BFloat16 elements.
    if (op.getLhsType().getElementType() != rewriter.getBF16Type())
      return rewriter.notifyMatchFailure(op,
                                         "input type is not a vector of bf16");

    mmlaOp = MMLA::Bfloat;
    swapOperands = false;
    lhs = op.getLhs();
    rhs = op.getRhs();
    acc = op.getAcc();

    // Initialize parameters for unrolling.
    iterationBounds = *op.getShapeForUnroll();
    if (iterationBounds.size() == 3)
      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
    else
      subTileShape = SmallVector<int64_t>({2, 4});

    return success();
  }
};

/// Lowering from a vector::contractOp to the ArmNeon smmla intrinsic. This
/// will tile any vector.contract into multiple smmla instructions with
/// unrolling so long as [2, 2, 8] is a divisor of its shape. It can also
/// process vecmats with dimM = 1 (either explicitly or inferred if LHS has
/// only dimK). If no unrolling is necessary, a single smmla instruction is
/// emitted.
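/// For example (illustrative shapes), a contraction whose LHS is a
/// vector<4x16xi8> sign-extended to i32, whose RHS is a vector<2x16xi8>
/// (N x K) likewise extended, and whose accumulator is a vector<4x2xi32>,
/// has iteration bounds [4, 2, 16] and is unrolled into 2 * 1 * 2 = 4 smmla
/// operations, with the two K sub-tiles of each (M, N) pair chained through
/// the accumulator.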
class LowerContractionToNeonI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    VectorContractRewriterI8MM vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    vcr.lower(op, rewriter);

    return success();
  }
};
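/// Analogous lowering from a vector::contractOp to the ArmNeon bfmmla
/// intrinsic, for bf16 operands and an f32 accumulator, unrolling over
/// [2, 2, 4] sub-tiles.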
class LowerContractionToNeonBFMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    VectorContractRewriterBFMMLA vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    vcr.lower(op, rewriter);

    return success();
  }
};

} // namespace

void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
}

void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
}