28 #define DEBUG_TYPE "lower-contract-to-arm-neon"
44 template <
typename Op>
45 std::optional<Value> getExtOperand(
Value v) {
47 static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
48 "Must be instantiated with either sign- or zero- extension op");
54 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
55 auto eltTy = cast<VectorType>(v.
getType()).getElementType();
56 if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
65 auto inOp = extOp.getIn();
66 auto inTy = dyn_cast<VectorType>(inOp.getType());
69 auto inEltTy = inTy.getElementType();
70 if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
73 auto outTy = dyn_cast<VectorType>(extOp.getType());
74 if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
86 return signExt ? rewriter.
createOrFold<arith::ExtSIOp>(loc, targetTy, val)
87 : rewriter.
createOrFold<arith::ExtUIOp>(loc, targetTy, val);
90 class VectorContractRewriter {
103 MMLA mmlaOp = MMLA::Nop;
108 bool swapOperands =
false;
138 case MMLA::SignedInt:
141 case MMLA::UnsignedInt:
148 return arm_neon::BfmmlaOp::create(rewriter, loc, acc.
getType(), acc, lhs,
151 llvm_unreachable(
"Uninitialized operation type");
157 LogicalResult matchAndInit(vector::ContractionOp op,
161 if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
162 itTypes[1] != vector::IteratorType::parallel ||
163 itTypes[2] != vector::IteratorType::reduction) &&
164 (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
165 itTypes[1] != vector::IteratorType::reduction))
167 op,
"iterator types do not correspond to matrix multiplication");
170 VectorType lhsType = op.getLhsType();
171 VectorType rhsType = op.getRhsType();
172 if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
173 rhsType.getRank() != 2)
178 if (lhsType.isScalable() || rhsType.isScalable())
180 "Not applicable to scalable vectors");
183 dimM = lhsType.getDimSize(0);
184 dimN = rhsType.getDimSize(0);
185 dimK = rhsType.getDimSize(1);
188 if (lhsType.getRank() == 1) {
190 lhsDimK = lhsType.getDimSize(0);
192 lhsDimK = lhsType.getDimSize(1);
204 auto inputElementType = cast<ShapedType>(lhs.
getType()).getElementType();
205 auto accElementType = cast<ShapedType>(acc.
getType()).getElementType();
206 auto inputExpandedType =
212 auto collapsedInputType =
214 auto collapsedOutputType =
218 auto indexingMaps = op.getIndexingMapsArray();
219 AffineMap &lhsPermutationMap = indexingMaps[0];
220 AffineMap &rhsPermutationMap = indexingMaps[1];
221 AffineMap &accPermutationMap = indexingMaps[2];
228 arith::ConstantOp::create(rewriter, loc, op.getResultType(),
232 if (iterationBounds.size() == 3)
233 loopOrder.push_back(2);
240 auto extractOperand = [&](
Value operand,
AffineMap permutationMap,
245 return rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
246 loc, operand, operandOffsets, operandShape, operandStrides);
252 Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
255 Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
258 Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
263 auto expandRowVector = [&](
Value tiledOperand,
264 VectorType expandedTypeType) {
266 arith::ConstantOp::create(rewriter, loc, expandedTypeType,
269 cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
271 cast<ShapedType>(tiledOperand.
getType()).getRank(), 1);
272 return rewriter.
createOrFold<vector::InsertStridedSliceOp>(
273 loc, tiledOperand, emptyOperand, offsets, strides);
275 tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
276 tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
283 tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
287 auto collapsedLhs = rewriter.
createOrFold<vector::ShapeCastOp>(
288 tiledLhs.
getLoc(), collapsedInputType, tiledLhs);
289 auto collapsedRhs = rewriter.
createOrFold<vector::ShapeCastOp>(
290 tiledRhs.
getLoc(), collapsedInputType, tiledRhs);
292 bool initialKAcc = offsets.back() == 0;
297 collapsedRes = rewriter.
createOrFold<vector::ShapeCastOp>(
298 tiledAcc.
getLoc(), collapsedOutputType, tiledAcc);
303 createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
312 tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
318 tiledRes = rewriter.
createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
323 cast<ShapedType>(tiledRes.
getType()).getRank(), 1);
324 result = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
325 loc, tiledRes, result, accOffsets, strides);
332 class VectorContractRewriterI8MM :
public VectorContractRewriter {
334 LogicalResult matchAndInit(vector::ContractionOp op,
336 if (
failed(VectorContractRewriter::matchAndInit(op, rewriter)))
341 if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
348 mmlaOp = MMLA::SignedInt;
349 auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
351 mmlaOp = MMLA::UnsignedInt;
352 maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
356 op,
"LHS is not a sign- or zero- extended iN, N <= 8");
358 auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
360 if (mmlaOp == MMLA::UnsignedInt)
361 mmlaOp = MMLA::MixedInt;
363 if (mmlaOp == MMLA::SignedInt) {
364 mmlaOp = MMLA::MixedInt;
367 maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
372 op,
"RHS is not a sign- or zero- extended iN, N <= 8");
380 auto lhsExtInType = cast<VectorType>(lhs.getType());
381 if (lhsExtInType.getElementTypeBitWidth() < 8)
382 lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
384 (mmlaOp == MMLA::SignedInt ||
385 (mmlaOp == MMLA::MixedInt && !swapOperands)),
388 auto rhsExtInType = cast<VectorType>(rhs.getType());
389 if (rhsExtInType.getElementTypeBitWidth() < 8)
390 rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
392 (mmlaOp == MMLA::SignedInt ||
393 (mmlaOp == MMLA::MixedInt && swapOperands)),
397 iterationBounds = *op.getShapeForUnroll();
398 if (iterationBounds.size() == 3)
407 class VectorContractRewriterBFMMLA :
public VectorContractRewriter {
409 LogicalResult matchAndInit(vector::ContractionOp op,
412 if (
failed(VectorContractRewriter::matchAndInit(op, rewriter)))
417 if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
421 auto outTy = dyn_cast<VectorType>(op.getResultType());
422 if (!outTy || outTy.getElementType() != rewriter.
getF32Type())
424 "output type is not a vector of f32");
427 if (op.getLhsType().getElementType() != rewriter.
getBF16Type())
429 "input type is not a vector of bf16");
431 mmlaOp = MMLA::Bfloat;
432 swapOperands =
false;
438 iterationBounds = *op.getShapeForUnroll();
439 if (iterationBounds.size() == 3)
453 class LowerContractionToNeonI8MMPattern
457 LogicalResult matchAndRewrite(vector::ContractionOp op,
460 VectorContractRewriterI8MM vcr;
461 if (
failed(vcr.matchAndInit(op, rewriter)))
463 vcr.lower(op, rewriter);
469 class LowerContractionToNeonBFMMLAPattern
473 LogicalResult matchAndRewrite(vector::ContractionOp op,
476 VectorContractRewriterBFMMLA vcr;
477 if (
failed(vcr.matchAndInit(op, rewriter)))
479 vcr.lower(op, rewriter);
490 patterns.add<LowerContractionToNeonI8MMPattern>(context, 2);
496 patterns.add<LowerContractionToNeonBFMMLAPattern>(context, 2);
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns)
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...