#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
// Return the pre-extension operand of `v` for a sign- or zero-extension
// (selected by `Op`) from i8 to i32; for the signed case also accept an
// implicitly sign-extensible i8 vector. Return an empty optional otherwise.
template <typename Op>
std::optional<Value> getExtOperand(Value v) {
  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                "Must be instantiated with either sign- or zero- extension op");

  // No explicit extension op: allow an implicit sign-extension of an i8 vector.
  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
  if (!extOp) {
    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
      auto vTy = cast<VectorType>(v.getType());
      if (!vTy.getElementType().isSignlessInteger(8))
        return {};
      return v;
    }
    return {};
  }

  // Explicit extension op: check it extends from i8 to i32.
  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
    return {};

  return inOp;
}
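// For illustration (an assumption, not from the original source): given
//   %e = arith.extsi %v : vector<4x8xi8> to vector<4x8xi32>
// getExtOperand<arith::ExtSIOp>(%e) returns %v, an i8 vector passed directly
// (implicit sign-extension) is returned unchanged, and any other input yields
// an empty optional.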
// Kind of sub-tile matrix multiply-accumulate operation to emit.
enum class MMLA { Nop, SignedInt, UnsignedInt, MixedInt, Bfloat };

// Shared machinery for rewriting a `vector.contract` into a tile of *MMLA ops;
// the i8mm and bfloat16 rewrites derive from this class.
class VectorContractRewriter {
protected:
  MMLA mmlaOp = MMLA::Nop;
  // Swap LHS and RHS when emitting the sub-tile op (mixed-sign integer case).
  bool swapOperands = false;
  // ... (the `lhs`/`rhs` operands, the m/n/k dimensions and createMMLA, elided)

  // Checks common to the integer and bfloat16 rewrites.
  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);

public:
  VectorContractRewriter() = default;
  // ... (lower(), elided)
};
Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
                                         Location loc, Value acc, Value lhs,
                                         Value rhs) {
  if (swapOperands)
    std::swap(lhs, rhs);
  auto resTy = acc.getType();
  switch (mmlaOp) {
  case MMLA::SignedInt:
    return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::UnsignedInt:
    return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::MixedInt:
    return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::Bfloat:
    return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  default:
    llvm_unreachable("Uninitialized operation kind");
  }
}
LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
                                            PatternRewriter &rewriter) {
  // The iteration space must be (parallel, parallel, reduction).
  auto itTypes = op.getIteratorTypesArray();
  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
      itTypes[1] != vector::IteratorType::parallel ||
      itTypes[2] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(
        op, "iterator types do not correspond to matrix multiplication");

  // The indexing maps must be (m, k), (n, k) and (m, n), respectively.
  if (op.getIndexingMapsArray()[0] != AffineMap::getMultiDimMapWithTargets(
                                          3, ArrayRef{0u, 2u}, op.getContext()) ||
      op.getIndexingMapsArray()[1] != AffineMap::getMultiDimMapWithTargets(
                                          3, ArrayRef{1u, 2u}, op.getContext()) ||
      op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
                                          3, ArrayRef{0u, 1u}, op.getContext()))
    return rewriter.notifyMatchFailure(
        op, "indexing maps do not correspond to matrix multiplication");

  // The reduction must combine with addition.
  if (op.getKind() != vector::CombiningKind::ADD)
    return rewriter.notifyMatchFailure(op, "combining kind is not an addition");

  return success();
}
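// For illustration (an assumption, not from the original source), a
// contraction accepted by `match` looks like:
//
//   %res = vector.contract {
//       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,   // (m, k)
//                        affine_map<(d0, d1, d2) -> (d1, d2)>,   // (n, k)
//                        affine_map<(d0, d1, d2) -> (d0, d1)>],  // (m, n)
//       iterator_types = ["parallel", "parallel", "reduction"],
//       kind = #vector.kind<add>}
//     %lhs, %rhs, %acc
//       : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>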
Value VectorContractRewriter::lower(vector::ContractionOp op,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  Type operandEltType = cast<VectorType>(lhs.getType()).getElementType();
  Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();

  // Number of elements in a 128-bit sub-tile of an operand, resp. the result.
  const int64_t numOperandSubTileElts =
      128 / operandEltType.getIntOrFloatBitWidth();
  assert((resultEltType.isInteger(32) || resultEltType.isF32()) &&
         "Only implemented for i32 or f32 output");
  const int64_t numResultSubTileElts = 4;

  // ... (flattened 1-D vector types used below: flatLhsType, flatRhsType,
  // flatRhsTileType, flatAccType, accRowTy, accRow64Ty, accRowX2Ty,
  // accRowX264Ty; elided)
  // Build the LHS sub-tiles: concatenate rows i and i+1 into one flat vector
  // and replicate it across the whole scalable vector.
  SmallVector<Value> lhsTile;
  for (int64_t i = 0; i < m; i += 2) {
    auto r0 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
    auto r1 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
    SmallVector<int64_t> shuffleIdx(2 * k);
    std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
    auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
    auto s = vector::ScalableInsertOp::create(
        rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
    auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
    lhsTile.push_back(r);
  }
  // Flatten the RHS and slice it into the individual RHS sub-tiles.
  auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(),
                                         flatRhsTileType, this->rhs);
  SmallVector<Value> rhsTile;
  for (int64_t j = 0; j < n; j += 2)
    rhsTile.push_back(vector::ScalableExtractOp::create(
        rewriter, loc, flatRhsType, rhs, j * k));
  // Extract properly interleaved sub-tiles of the accumulator.
  SmallVector<Value> accTile;
  for (int64_t i = 0; i < m; i += 2) {
    // Extract two consecutive rows of the accumulator.
    auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i});
    auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i + 1});
    Value accTileVec;
    if (swapOperands) {
      // Operands are swapped: the 2x2 sub-tiles are transposed, so interleave
      // the rows element by element.
      accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
    } else {
      // Bitcast to 64-bit elements so the interleave keeps pairs of 32-bit
      // values together, flattening each 2x2 sub-tile into four consecutive
      // elements, then bitcast back.
      auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
      auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);
      auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);
      accTileVec =
          vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
    }
    // Slice the combined rows into the individual accumulator sub-tiles.
    for (int64_t j = 0; j < n; j += 2)
      accTile.push_back(vector::ScalableExtractOp::create(
          rewriter, loc, flatAccType, accTileVec, j * 2));
  }
  // Emit a *MMLA operation for each (accumulator, LHS, RHS) sub-tile triple.
  SmallVector<Value> outTile;
  for (int64_t i = 0; i < m / 2; ++i)
    for (int64_t j = 0; j < n / 2; ++j) {
      Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j],
                              lhsTile[i], rhsTile[j]);
      outTile.push_back(mmla);
    }
  // Unpack the OUT sub-tiles and insert them into the result rows.
  Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
  for (int64_t i = 0; i < m / 2; ++i) {
    // Gather the n/2 sub-tiles of two consecutive result rows.
    Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
    for (int64_t j = 0; j < n / 2; ++j)
      row = vector::ScalableInsertOp::create(
          rewriter, loc, outTile[i * n / 2 + j], row, j * 4);

    // Split the combined row back into the two result rows.
    Value out0, out1;
    if (swapOperands) {
      // The sub-tiles were computed transposed: deinterleave element-wise.
      auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
      out0 = tmp.getRes1();
      out1 = tmp.getRes2();
    } else {
      // Deinterleave pairs of 32-bit elements via a bitcast to 64 bits.
      auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
      auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);
      out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes1());
      out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes2());
    }
    result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2);
    result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1);
  }

  return result;
}
class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  // Check the i8mm-specific preconditions and initialize the rewrite.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // LHS must be a fixed MxK vector, RHS a scalable [N]xK vector, K == 8,
    // and M and N even and at least 2.
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "unsupported operand shapes");

    // The output must be a vector of i32.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of i32");

    // Determine the extension kind of the LHS.
    mmlaOp = MMLA::SignedInt;
    swapOperands = false;
    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
    if (!maybeLhs) {
      mmlaOp = MMLA::UnsignedInt;
      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
    }
    if (!maybeLhs)
      return rewriter.notifyMatchFailure(
          op, "LHS is not a sign- or zero- extended i8");

    // Determine the extension kind of the RHS; a signed/unsigned mix maps to
    // the usmmla operation, swapping operands if necessary.
    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
    if (maybeRhs) {
      if (mmlaOp == MMLA::UnsignedInt)
        mmlaOp = MMLA::MixedInt;
    } else {
      if (mmlaOp == MMLA::SignedInt) {
        mmlaOp = MMLA::MixedInt;
        swapOperands = true;
      }
      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
    }
    if (!maybeRhs)
      return rewriter.notifyMatchFailure(
          op, "RHS is not a sign- or zero- extended i8");

    lhs = *maybeLhs;
    rhs = *maybeRhs;
    return success();
  }
};
class VectorContractRewriterBfloat : public VectorContractRewriter {
public:
  // Check the bfloat16-specific preconditions and initialize the rewrite.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // LHS must be a fixed MxK vector, RHS a scalable [N]xK vector, K == 4,
    // and M and N even and at least 2.
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "unsupported operand shapes");

    // The output must be a vector of f32 and the inputs vectors of bf16.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of f32");

    if (lhsType.getElementType() != rewriter.getBF16Type())
      return rewriter.notifyMatchFailure(op,
                                         "input type is not a vector of bf16");

    mmlaOp = MMLA::Bfloat;
    swapOperands = false;
    lhs = op.getLhs();
    rhs = op.getRhs();
    return success();
  }
};
// Rewrite a `vector.contract` into SVE i8mm operations.
class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    VectorContractRewriterI8MM vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};

// Rewrite a `vector.contract` into SVE BFMMLA operations.
class LowerContractionToSVEBFMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    VectorContractRewriterBfloat vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();
    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
void mlir::populateLowerContractionToSVEI8MMPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEI8MMPattern>(context, 2);
}

void mlir::populateLowerContractionToSVEBFMMLAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEBFMMLAPattern>(context, 2);
}
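// A minimal usage sketch (an assumption, not part of the original file):
// inside a pass, the two populate functions above can be combined and applied
// greedily; `funcOp` stands for whatever op the pass anchors on.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   populateLowerContractionToSVEI8MMPatterns(patterns);
//   populateLowerContractionToSVEBFMMLAPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();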