31#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
46std::optional<Value> getExtOperand(
Value v) {
48 static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
49 "Must be instantiated with either sign- or zero- extension op");
55 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
56 auto vTy = cast<VectorType>(v.
getType());
57 if (!vTy.getElementType().isSignlessInteger(8))
66 auto inOp = extOp.getIn();
67 auto inTy = dyn_cast<VectorType>(inOp.getType());
68 if (!inTy || !inTy.getElementType().isSignlessInteger(8))
71 auto outTy = dyn_cast<VectorType>(extOp.getType());
72 if (!outTy || !outTy.getElementType().isSignlessInteger(32))
158class VectorContractRewriter {
171 MMLA mmlaOp = MMLA::Nop;
176 bool swapOperands =
false;
197 LogicalResult match(vector::ContractionOp op,
PatternRewriter &rewriter);
200 VectorContractRewriter() =
default;
216 case MMLA::SignedInt:
217 return arm_sve::SmmlaOp::create(rewriter, loc, resTy,
acc,
lhs,
rhs);
218 case MMLA::UnsignedInt:
219 return arm_sve::UmmlaOp::create(rewriter, loc, resTy,
acc,
lhs,
rhs);
221 return arm_sve::UsmmlaOp::create(rewriter, loc, resTy,
acc,
lhs,
rhs);
223 return arm_sve::BfmmlaOp::create(rewriter, loc, resTy,
acc,
lhs,
rhs);
225 llvm_unreachable(
"Uninitialized operation kind");
229LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
232 auto itTypes = op.getIteratorTypesArray();
233 if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
234 itTypes[1] != vector::IteratorType::parallel ||
235 itTypes[2] != vector::IteratorType::reduction)
237 op,
"iterator types do not correspond to matrix multiplication");
244 if (op.getIndexingMapsArray()[0] !=
247 op.getIndexingMapsArray()[1] !=
251 3,
ArrayRef{0u, 1u}, op.getContext()))
255 if (op.getKind() != vector::CombiningKind::ADD)
261Value VectorContractRewriter::lower(vector::ContractionOp op,
265 Type operandEltType = cast<VectorType>(
lhs.getType()).getElementType();
266 Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();
268 const int64_t numOperandSubTileElts =
272 "Only implemented for i32 or f32 output");
273 const int64_t numResultSubTileElts = 4;
278 VectorType::get(numOperandSubTileElts, operandEltType,
281 VectorType::get(numOperandSubTileElts, operandEltType,
284 VectorType::get(numResultSubTileElts, resultEltType,
289 auto flatRhsTileType = VectorType::get(k * n, operandEltType,
294 auto accRowTy = VectorType::get(n, resultEltType,
299 auto accRowX2Ty = VectorType::get(2 * n, resultEltType,
304 auto accRow64Ty = VectorType::get(n / 2, rewriter.
getI64Type(),
309 auto accRowX264Ty = VectorType::get(n, rewriter.
getI64Type(),
316 for (
int64_t i = 0; i < m; i += 2) {
324 std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
325 auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
327 auto s = vector::ScalableInsertOp::create(
328 rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
330 auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
331 lhsTile.push_back(r);
335 auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.
getLoc(),
336 flatRhsTileType, this->rhs);
341 rhsTile.push_back(vector::ScalableExtractOp::create(
342 rewriter, loc, flatRhsType,
rhs,
j * k));
346 for (
int64_t i = 0; i < m; i += 2) {
348 auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
350 auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
357 accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
361 auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
362 auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);
366 auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);
370 vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
374 accTile.push_back(vector::ScalableExtractOp::create(
375 rewriter, loc, flatAccType, accTileVec,
j * 2));
380 for (
int64_t i = 0; i < m / 2; ++i)
382 Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 +
j], lhsTile[i],
384 outTile.push_back(mmla);
388 Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
389 for (
int64_t i = 0; i < m / 2; ++i) {
391 Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
393 row = vector::ScalableInsertOp::create(
394 rewriter, loc, outTile[i * n / 2 +
j], row,
j * 4);
402 auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
403 out0 = tmp.getRes1();
404 out1 = tmp.getRes2();
407 auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
408 auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);
411 out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
413 out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
416 result = vector::InsertOp::create(rewriter, loc, out0,
result, i * 2);
417 result = vector::InsertOp::create(rewriter, loc, out1,
result, i * 2 + 1);
423class VectorContractRewriterI8MM :
public VectorContractRewriter {
427 LogicalResult matchAndInit(vector::ContractionOp op,
429 if (failed(match(op, rewriter)))
432 VectorType lhsType = op.getLhsType();
433 VectorType rhsType = op.getRhsType();
435 m = lhsType.getDimSize(0);
436 n = rhsType.getDimSize(0);
437 k = rhsType.getDimSize(1);
444 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
445 rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
446 m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
447 !rhsType.getScalableDims()[0])
451 auto outTy = dyn_cast<VectorType>(op.getResultType());
452 if (!outTy || outTy.getElementType() != rewriter.
getI32Type())
454 "output type is not a vector of i32");
460 mmlaOp = MMLA::SignedInt;
461 swapOperands =
false;
462 auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
464 mmlaOp = MMLA::UnsignedInt;
465 maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
469 op,
"LHS is not a sign- or zero- extended i8");
471 auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
473 if (mmlaOp == MMLA::UnsignedInt)
474 mmlaOp = MMLA::MixedInt;
476 if (mmlaOp == MMLA::SignedInt) {
477 mmlaOp = MMLA::MixedInt;
480 maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
484 op,
"RHS is not a sign- or zero- extended i8");
495class VectorContractRewriterBfloat :
public VectorContractRewriter {
499 LogicalResult matchAndInit(vector::ContractionOp op,
501 if (failed(match(op, rewriter)))
504 VectorType lhsType = op.getLhsType();
505 VectorType rhsType = op.getRhsType();
507 m = lhsType.getDimSize(0);
508 n = rhsType.getDimSize(0);
509 k = rhsType.getDimSize(1);
516 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
517 rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
518 m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 ||
519 !rhsType.getScalableDims()[0])
523 auto outTy = dyn_cast<VectorType>(op.getResultType());
524 if (!outTy || outTy.getElementType() != rewriter.
getF32Type())
526 "output type is not a vector of f32");
529 if (lhsType.getElementType() != rewriter.
getBF16Type())
531 "input type is not a vector of bf16");
534 mmlaOp = MMLA::Bfloat;
535 swapOperands =
false;
544class LowerContractionToSVEI8MMPattern
548 LogicalResult matchAndRewrite(vector::ContractionOp op,
552 VectorContractRewriterI8MM vcr;
553 if (failed(vcr.matchAndInit(op, rewriter)))
563class LowerContractionToSVEBFMMLAPattern
567 LogicalResult matchAndRewrite(vector::ContractionOp op,
571 VectorContractRewriterBfloat vcr;
572 if (failed(vcr.matchAndInit(op, rewriter)))
587 patterns.add<LowerContractionToSVEI8MMPattern>(context, 2);
593 patterns.add<LowerContractionToSVEBFMMLAPattern>(context, 2);
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns)
void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.