19#include "llvm/Support/Casting.h"
34 ShapedType inputType = cast<ShapedType>(input.
getType());
35 int64_t firstDimToCollapse = inputType.getRank() - 2;
37 if (inputType.getRank() == 1)
41 for (
int64_t i = 0; i < firstDimToCollapse; ++i)
45 for (
int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
46 collapsedIndices.push_back(i);
48 reassociation.push_back(collapsedIndices);
49 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
54static FailureOr<std::pair<Value, SmallVector<Value>>>
64 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
66 readOp.getIndices().end());
67 srcBuff = readOp.getOperand(0);
77 indices.reserve(indexVals.size());
88 return std::make_pair(srcBuff,
indices);
92static LogicalResult validateContractOps(
OpBuilder &rewriter,
93 vector::ContractionOp contractOp,
94 unsigned int blockingFactor,
100 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
101 contractOp.getLhs(),
false);
104 auto [buffLhs, indicesLhs] = *srcIndxLhs;
107 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
108 contractOp.getRhs(),
false);
111 auto [buffRhs, indicesRhs] = *srcIndxRhs;
114 if (buffLhs != srcBuffLhs)
117 if (buffRhs != srcBuffRhs)
121 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
128 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
129 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
131 if (nonUnitDimAcc.size() != 0)
136 VectorType lhsTy = contractOp.getLhsType();
139 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
140 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
142 if (nonUnitDimLhs.size() != 1)
145 if (nonUnitDimLhs[0] != blockingFactor)
150 VectorType rhsTy = contractOp.getRhsType();
153 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
154 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
156 if (nonUnitDimRhs.size() != 1)
159 if (nonUnitDimRhs[0] != blockingFactor)
167static unsigned getIndexPosition(
Value operand, scf::ForOp loop) {
168 Value iv = loop.getInductionVar();
172 .Case<TransferReadOp, LoadOp>(
173 [&](
auto readOp) { srcBuff = readOp.getOperand(0); });
179 auto offsets = subview.getOffsets();
181 for (
auto it : llvm::enumerate(offsets)) {
182 if (it.value() == iv)
192 bool rhs,
unsigned int offset) {
194 auto srcIndx = getSrcIndxValue(rewriter, loc, operand,
false);
195 auto [srcBuff,
indices] = *srcIndx;
204 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
205 return amx::TileLoadOp::create(rewriter, loc, tileType, mat,
indices);
213 unsigned int offset) {
224 for (
size_t i = 0; i < ops.size(); i++) {
226 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
227 amx::TileLoadOp tilesLhs;
228 auto itLhs = readsToTileLoads.find(readOpLhs);
229 if (itLhs != readsToTileLoads.end()) {
230 tilesLhs = itLhs->second;
232 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(),
233 subviewCollapseLhs, ipType,
false, offset);
234 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
237 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
238 amx::TileLoadOp tilesRhs;
239 auto itRhs = readsToTileLoads.find(readOpRhs);
240 if (itRhs != readsToTileLoads.end()) {
241 tilesRhs = itRhs->second;
243 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(),
244 subviewCollapseRhs, ipType,
true, offset);
245 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
248 auto accTileType = amx::TileType::get({16, 16}, opType);
252 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
253 tilesRhs, accIterArgs[i]);
256 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
257 tilesRhs, accIterArgs[i]);
259 accumulators.push_back(dp);
265 Type opType, scf::ForOp outerLoop,
270 auto zeroTileType = amx::TileType::get({16, 16}, opType);
272 for (
int i = 0; i < size; i++) {
273 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
274 loopItrArgs.push_back(zeroTile);
296struct VectorContractToAMXDotProduct
298 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
300 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
301 PatternRewriter &rewriter)
const override {
303 if (contractOp.getKind() != vector::CombiningKind::ADD)
305 "Expects add combining kind.");
307 unsigned int blockingFactor =
308 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
311 contractOp.getIndexingMapsArray(), blockingFactor);
313 VectorType lhsTy = contractOp.getLhsType();
314 if (!lhsTy.getElementType().isBF16() &&
315 !lhsTy.getElementType().isSignlessInteger(8))
317 contractOp,
"Only BF16/Int8 lowering is supported.");
319 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
323 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
324 (lhsTy.getElementType().isSignlessInteger(8) &&
325 !accTy.getElementType().isSignlessInteger(32)))
327 "Only F32 for BF16 or Int32 for Int8 "
328 "accumulation type is supported.");
331 contractOp,
"Only VNNI-packed inputs are supported.");
333 Operation *accReadOp =
336 Operation *resultWriteOp =
339 if (!accReadOp || !resultWriteOp)
341 contractOp,
"The ACC operand of the vector.contract should be a "
342 "transfer_read or a load. And, the result should be "
343 "stored using transfer_write or store.");
348 if (lhsTy.getElementType().isSignlessInteger(8)) {
353 if (accReadOp->
getBlock() == contractOp->getBlock() &&
354 resultWriteOp->
getBlock() != contractOp->getBlock())
356 contractOp,
"The accumulator store is in different block.");
358 if (accReadOp->
getBlock() != contractOp->getBlock() &&
359 resultWriteOp->
getBlock() == contractOp->getBlock())
361 contractOp,
"The accumulator read is in different block.");
365 if (accReadOp->
getBlock() == contractOp->getBlock() &&
366 resultWriteOp->
getBlock() == contractOp->getBlock()) {
368 LogicalResult validate = validateContractOps(
369 rewriter, contractOp, blockingFactor, Value(), Value(),
false);
373 contractOp,
"The contract operation doesn't satisfy the operands "
374 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
375 "The rest dims should be 1.");
377 Location loc = contractOp.getLoc();
379 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
380 contractOp.getLhs(),
true);
383 "The LHS src is not a MemRef type.");
384 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
386 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
387 contractOp.getRhs(),
true);
390 "The RHS src is not a MemRef type.");
391 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
393 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
394 contractOp.getAcc(),
false);
397 "The ACC src is not a MemRef type.");
398 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
401 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
402 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
403 srcBuffLhs, indicesLhs);
404 auto loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType,
405 srcBuffRhs, indicesRhs);
407 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
408 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
409 srcBuffAcc, indicesAcc);
414 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
418 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
421 amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
423 rewriter.
eraseOp(resultWriteOp);
430 SmallVector<scf::ForOp> loopLists;
431 Operation *current = contractOp;
435 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
444 if (loopLists.size() > 2 || loopLists.size() == 0)
446 contractOp,
"Rewrite is supported until reduction loop depth of 2.");
448 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
449 contractOp.getLhs(),
false);
452 "The LHS src is not a MemRef type.");
453 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
455 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
456 contractOp.getRhs(),
false);
459 "The RHS src is not a MemRef type.");
460 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
462 Operation *vectorOpLhs;
463 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
464 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
465 vectorOpLhs = readOp.getBase().getDefiningOp();
468 Operation *vectorOpRhs;
469 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
470 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
471 vectorOpRhs = readOp.getBase().getDefiningOp();
475 SmallVector<vector::ContractionOp> ops;
476 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
478 if (
auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
480 LogicalResult validate = validateContractOps(
481 rewriter,
contract, blockingFactor, srcBuffLhs, srcBuffRhs,
true);
485 contractOp,
"The associated contract operations doesn't satisfy "
486 "the re-write conditions either the dimensions are "
487 "wrong or MemRef source are different.");
493 scf::ForOp outerLoop;
494 scf::ForOp innerLoop;
498 if (loopLists.size() == 2) {
499 outerLoop = loopLists[1];
500 innerLoop = loopLists[0];
502 SmallVector<Value> loopItrArgs = createTileZeros(
503 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
505 newLoop = scf::ForOp::create(
506 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
507 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
508 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
509 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
510 auto newInnerLoop = scf::ForOp::create(
511 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
512 innerLoop.getUpperBound(), innerLoop.getStep(),
514 [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
515 Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
518 vectorOpLhs->getOperand(
519 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
522 vectorOpLhs->getOperand(
523 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
526 rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
528 IRMapping rhsMapping;
530 vectorOpRhs->getOperand(
531 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
534 vectorOpRhs->getOperand(
535 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
538 rewriterNewInnerLoop.clone(*vectorOpRhs, rhsMapping);
540 SmallVector<Value> accumulators = createTiledDp(
541 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0),
542 rhsClone->getResult(0), ipType, opType,
543 iterArgsNewInnerLoop, blockingFactor);
545 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
549 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
550 newInnerLoop.getResults());
555 if (loopLists.size() == 1) {
556 outerLoop = loopLists[0];
558 SmallVector<Value> loopItrArgs = createTileZeros(
559 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
560 newLoop = scf::ForOp::create(
561 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
562 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
563 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
564 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
567 vectorOpLhs->getOperand(
568 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
571 auto lhsClone = rewriterOuterLoop.clone(*vectorOpLhs, mapping);
573 IRMapping rhsMapping;
575 vectorOpRhs->getOperand(
576 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
579 auto rhsClone = rewriterOuterLoop.clone(*vectorOpRhs, rhsMapping);
581 SmallVector<Value> accumulators = createTiledDp(
582 rewriter, locOuterLoop, ops, lhsClone->getResult(0),
583 rhsClone->getResult(0), ipType, opType, iterArgsOuterLoop,
586 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop, accumulators);
593 auto bufferType = MemRefType::get({16, 16}, opType);
595 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
597 SmallVector<Value> dps = newLoop.getResults();
598 for (
size_t i = 0; i < ops.size(); i++) {
599 vector::ContractionOp contOp = ops[i];
600 Operation *resultWriteOp =
607 amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), bBuffer,
617 rewriter, outerLoop.getLoc(), c0, mBound, one,
ValueRange{},
618 [&](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
619 auto resultAcc = vector::LoadOp::create(
620 rewriter, loc, VectorType::get(16, opType), bBuffer,
623 Operation *accReadOp =
627 SmallVector<Value> indicesAcc;
629 llvm::TypeSwitch<Operation *>(accReadOp)
630 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
631 srcBuffAcc = readOp.getOperand(0);
633 auto indices = readOp.getIndices();
634 indicesAcc.reserve(
indices.size());
637 indices, std::back_inserter(indicesAcc),
638 [&](OpFoldResult ofr) {
644 Value sum = arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
645 indicesAcc[indicesAcc.size() - 2] = sum;
647 auto acc = vector::LoadOp::create(rewriter, loc,
648 VectorType::get(16, opType),
649 srcBuffAcc, indicesAcc);
652 addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
655 addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
657 vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
660 scf::YieldOp::create(builder, outerLoop.getLoc());
663 rewriter.
eraseOp(resultWriteOp);
674 patterns.
add<VectorContractToAMXDotProduct>(patterns.
getContext());
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
mlir::x86::AMXTileType TileType
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...