19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Casting.h"
22#include "llvm/ADT/ArrayRef.h"
28static FailureOr<SmallVector<mlir::utils::IteratorType>>
33 map.
getNumDims(), mlir::utils::IteratorType::reduction);
35 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
36 iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
43 std::optional<unsigned> blockingFactor) {
45 FailureOr<linalg::ContractionDimensions> dims =
52 auto typeA = dyn_cast<ShapedType>(matA.getType());
53 auto typeB = dyn_cast<ShapedType>(matB.getType());
54 unsigned rankA = typeA.getRank();
55 unsigned rankB = typeB.getRank();
57 if (rankA < 3 || rankB < 3)
62 if (dims->k.size() < 2)
71 if (failed(maybeIters))
77 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 1));
78 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 1));
79 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
80 iteratorTypes[vnniDimA.getPosition()] !=
81 mlir::utils::IteratorType::reduction)
83 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 2));
84 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 3));
85 if (!redDimA || !redDimB || redDimA != redDimB ||
86 iteratorTypes[redDimA.getPosition()] !=
87 mlir::utils::IteratorType::reduction)
89 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 2));
90 if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
91 mlir::utils::IteratorType::parallel)
98 auto vnniDimSize = typeB.getShape().back();
99 if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
100 vnniDimSize % 2 != 0)
102 if (typeA.getShape().back() != vnniDimSize)
104 if (blockingFactor && vnniDimSize != *blockingFactor)
108 if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
121 assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
122 "Unsupported nonUnitDimAcc value");
124 static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
125 static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
128 static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
129 4, 5, 6, 7, 20, 21, 22, 23};
130 static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
131 12, 13, 14, 15, 28, 29, 30, 31};
133 if (nonUnitDimAcc == 16)
134 return {maskLo16, maskHi16};
136 return {maskLo8, maskHi8};
150 if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
157 if (
auto barg = dyn_cast<BlockArgument>(v)) {
158 auto *parentOp = barg.getOwner()->getParentOp();
160 if (
auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
161 unsigned argNum = barg.getArgNumber();
167 unsigned iterIdx = argNum - 1;
168 v = forOp.getInitArgs()[iterIdx];
195 if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
198 if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
202 if (
auto yield = dyn_cast<scf::YieldOp>(user)) {
204 unsigned idx = use.getOperandNumber();
212 if (
auto forOp = dyn_cast<scf::ForOp>(user)) {
213 unsigned idx = use.getOperandNumber();
236 vector::ContractionOp contractA,
237 vector::ContractionOp contractB,
238 int64_t nonUnitDimAcc, VectorType accTy) {
240 if (!isa<vector::TransferReadOp, vector::LoadOp>(opA) ||
241 !isa<vector::TransferReadOp, vector::LoadOp>(opB)) {
250 auto elemTy = accTy.getElementType();
251 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
254 vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->
getResult(0));
256 vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->
getResult(0));
260 auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
261 castB, masks.maskLo);
262 auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
263 castB, masks.maskHi);
265 auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
266 auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
270 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
275 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
289 if (
auto write = dyn_cast<vector::TransferWriteOp>(op))
290 return write.getVector();
291 if (
auto store = dyn_cast<vector::StoreOp>(op))
292 return store.getValueToStore();
296 Value vecA = getWrittenVector(opA);
297 Value vecB = getWrittenVector(opB);
308 auto elemTy = accTy.getElementType();
309 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
312 auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
313 auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
318 auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
319 castB, masks.maskLo);
320 auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
321 castB, masks.maskHi);
324 auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
325 auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
341 vector::ContractionOp pairContOp,
342 bool rhsHasMultipleNonUnitDims,
344 if (rhsHasMultipleNonUnitDims &&
345 !(contractOp.getLhs() == pairContOp.getLhs()))
348 if (!rhsHasMultipleNonUnitDims &&
349 !(contractOp.getRhs() == pairContOp.getRhs()))
352 auto nonUnitOperand =
353 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
354 auto nonUnitOperandPairContOp =
355 rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
360 .Case<vector::TransferReadOp, vector::LoadOp>([&](
auto readOp) {
361 srcBuff = readOp.getOperand(0);
363 readOp.getIndices().end());
365 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
366 srcBuff = op.getSource();
370 Value srcBuffPairContOp;
373 .Case<vector::TransferReadOp, vector::LoadOp>([&](
auto readOp) {
374 srcBuffPairContOp = readOp.getOperand(0);
376 readOp.getIndices().begin(), readOp.getIndices().end());
378 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
379 srcBuffPairContOp = op.getSource();
383 if (!srcBuff || !srcBuffPairContOp)
387 auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();
389 if (shuffleLw && shuffleHw)
390 return shuffleLw.getV1() == shuffleHw.getV1() &&
391 shuffleLw.getV2() == shuffleHw.getV2();
393 if (srcBuff != srcBuffPairContOp)
396 for (
size_t i = 0; i < indexVals.size(); i++) {
406 if ((*v1 - *v0) != nonUnitDimValue)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
AffineExpr getResult(unsigned idx) const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void setOperand(unsigned idx, Value value)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumUses() const
This method computes the number of uses of this Value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Operation * traceToVectorReadLikeParentOperation(Value v)
static FailureOr< SmallVector< mlir::utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Operation * traceToVectorWriteLikeUserOperation(Value v)
ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::ArrayRef< int64_t > maskLo
llvm::ArrayRef< int64_t > maskHi