28#define DEBUG_TYPE "lower-contract-to-arm-neon"
45std::optional<Value> getExtOperand(
Value v) {
47 static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
48 "Must be instantiated with either sign- or zero- extension op");
54 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
55 auto eltTy = cast<VectorType>(v.
getType()).getElementType();
56 if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
65 auto inOp = extOp.getIn();
66 auto inTy = dyn_cast<VectorType>(inOp.getType());
69 auto inEltTy = inTy.getElementType();
70 if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
73 auto outTy = dyn_cast<VectorType>(extOp.getType());
74 if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
86 return signExt ? rewriter.
createOrFold<arith::ExtSIOp>(loc, targetTy, val)
87 : rewriter.
createOrFold<arith::ExtUIOp>(loc, targetTy, val);
90class VectorContractRewriter {
103 MMLA mmlaOp = MMLA::Nop;
108 bool swapOperands =
false;
138 case MMLA::SignedInt:
141 case MMLA::UnsignedInt:
148 return arm_neon::BfmmlaOp::create(rewriter, loc,
acc.getType(),
acc,
lhs,
151 llvm_unreachable(
"Uninitialized operation type");
153 llvm_unreachable(
"Unknown MMLA");
158 LogicalResult matchAndInit(vector::ContractionOp op,
162 if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
163 itTypes[1] != vector::IteratorType::parallel ||
164 itTypes[2] != vector::IteratorType::reduction) &&
165 (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
166 itTypes[1] != vector::IteratorType::reduction))
168 op,
"iterator types do not correspond to matrix multiplication");
171 VectorType lhsType = op.getLhsType();
172 VectorType rhsType = op.getRhsType();
173 if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
174 rhsType.getRank() != 2)
179 if (lhsType.isScalable() || rhsType.isScalable())
181 "Not applicable to scalable vectors");
184 dimM = lhsType.getDimSize(0);
185 dimN = rhsType.getDimSize(0);
186 dimK = rhsType.getDimSize(1);
189 if (lhsType.getRank() == 1) {
191 lhsDimK = lhsType.getDimSize(0);
193 lhsDimK = lhsType.getDimSize(1);
205 auto inputElementType = cast<ShapedType>(
lhs.getType()).getElementType();
206 auto accElementType = cast<ShapedType>(
acc.getType()).getElementType();
207 auto inputExpandedType =
208 VectorType::get({2, subTileShape.back()}, inputElementType);
209 auto outputExpandedType = VectorType::get({2, 2}, accElementType);
213 auto collapsedInputType =
214 VectorType::get(inputExpandedType.getNumElements(), inputElementType);
215 auto collapsedOutputType =
216 VectorType::get(outputExpandedType.getNumElements(), accElementType);
219 auto indexingMaps = op.getIndexingMapsArray();
220 AffineMap &lhsPermutationMap = indexingMaps[0];
221 AffineMap &rhsPermutationMap = indexingMaps[1];
222 AffineMap &accPermutationMap = indexingMaps[2];
229 arith::ConstantOp::create(rewriter, loc, op.getResultType(),
233 if (iterationBounds.size() == 3)
234 loopOrder.push_back(2);
241 auto extractOperand = [&](
Value operand,
AffineMap permutationMap,
246 return rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
247 loc, operand, operandOffsets, operandShape, operandStrides);
253 Value tiledLhs = extractOperand(
lhs, lhsPermutationMap, lhsOffsets);
256 Value tiledRhs = extractOperand(
rhs, rhsPermutationMap, rhsOffsets);
259 Value tiledAcc = extractOperand(
acc, accPermutationMap, accOffsets);
264 auto expandRowVector = [&](
Value tiledOperand,
265 VectorType expandedTypeType) {
267 arith::ConstantOp::create(rewriter, loc, expandedTypeType,
270 cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
272 cast<ShapedType>(tiledOperand.
getType()).getRank(), 1);
273 return rewriter.
createOrFold<vector::InsertStridedSliceOp>(
274 loc, tiledOperand, emptyOperand, offsets, strides);
276 tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
277 tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
284 tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
288 auto collapsedLhs = rewriter.
createOrFold<vector::ShapeCastOp>(
289 tiledLhs.
getLoc(), collapsedInputType, tiledLhs);
290 auto collapsedRhs = rewriter.
createOrFold<vector::ShapeCastOp>(
291 tiledRhs.
getLoc(), collapsedInputType, tiledRhs);
293 bool initialKAcc = offsets.back() == 0;
298 collapsedRes = rewriter.
createOrFold<vector::ShapeCastOp>(
299 tiledAcc.
getLoc(), collapsedOutputType, tiledAcc);
304 createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
313 tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
319 tiledRes = rewriter.
createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
324 cast<ShapedType>(tiledRes.
getType()).getRank(), 1);
326 loc, tiledRes,
result, accOffsets, strides);
333class VectorContractRewriterI8MM :
public VectorContractRewriter {
335 LogicalResult matchAndInit(vector::ContractionOp op,
337 if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
342 if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
349 mmlaOp = MMLA::SignedInt;
350 auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
352 mmlaOp = MMLA::UnsignedInt;
353 maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
357 op,
"LHS is not a sign- or zero- extended iN, N <= 8");
359 auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
361 if (mmlaOp == MMLA::UnsignedInt)
362 mmlaOp = MMLA::MixedInt;
364 if (mmlaOp == MMLA::SignedInt) {
365 mmlaOp = MMLA::MixedInt;
368 maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
373 op,
"RHS is not a sign- or zero- extended iN, N <= 8");
381 auto lhsExtInType = cast<VectorType>(
lhs.getType());
382 if (lhsExtInType.getElementTypeBitWidth() < 8)
383 lhs = extendSmallIntVector(loc, lhsExtInType,
lhs,
385 (mmlaOp == MMLA::SignedInt ||
386 (mmlaOp == MMLA::MixedInt && !swapOperands)),
389 auto rhsExtInType = cast<VectorType>(
rhs.getType());
390 if (rhsExtInType.getElementTypeBitWidth() < 8)
391 rhs = extendSmallIntVector(loc, rhsExtInType,
rhs,
393 (mmlaOp == MMLA::SignedInt ||
394 (mmlaOp == MMLA::MixedInt && swapOperands)),
398 iterationBounds = *op.getShapeForUnroll();
399 if (iterationBounds.size() == 3)
408class VectorContractRewriterBFMMLA :
public VectorContractRewriter {
410 LogicalResult matchAndInit(vector::ContractionOp op,
413 if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
418 if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
422 auto outTy = dyn_cast<VectorType>(op.getResultType());
423 if (!outTy || outTy.getElementType() != rewriter.
getF32Type())
425 "output type is not a vector of f32");
428 if (op.getLhsType().getElementType() != rewriter.
getBF16Type())
430 "input type is not a vector of bf16");
432 mmlaOp = MMLA::Bfloat;
433 swapOperands =
false;
439 iterationBounds = *op.getShapeForUnroll();
440 if (iterationBounds.size() == 3)
454class LowerContractionToNeonI8MMPattern
458 LogicalResult matchAndRewrite(vector::ContractionOp op,
461 VectorContractRewriterI8MM vcr;
462 if (failed(vcr.matchAndInit(op, rewriter)))
464 vcr.lower(op, rewriter);
470class LowerContractionToNeonBFMMLAPattern
474 LogicalResult matchAndRewrite(vector::ContractionOp op,
477 VectorContractRewriterBFMMLA vcr;
478 if (failed(vcr.matchAndInit(op, rewriter)))
480 vcr.lower(op, rewriter);
491 patterns.add<LowerContractionToNeonI8MMPattern>(context, 2);
497 patterns.add<LowerContractionToNeonBFMMLAPattern>(context, 2);
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns)
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...