19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Casting.h"
22#include "llvm/ADT/ArrayRef.h"
28static FailureOr<SmallVector<mlir::utils::IteratorType>>
33 map.
getNumDims(), mlir::utils::IteratorType::reduction);
35 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
36 iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
43 std::optional<unsigned> blockingFactor) {
45 FailureOr<linalg::ContractionDimensions> dims =
52 auto typeA = dyn_cast<ShapedType>(matA.getType());
53 auto typeB = dyn_cast<ShapedType>(matB.getType());
54 unsigned rankA = typeA.getRank();
55 unsigned rankB = typeB.getRank();
57 if (rankA < 3 || rankB < 3)
62 if (dims->k.size() < 2)
71 if (failed(maybeIters))
77 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 1));
78 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 1));
79 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
80 iteratorTypes[vnniDimA.getPosition()] !=
81 mlir::utils::IteratorType::reduction)
83 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 2));
84 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 3));
85 if (!redDimA || !redDimB || redDimA != redDimB ||
86 iteratorTypes[redDimA.getPosition()] !=
87 mlir::utils::IteratorType::reduction)
89 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 2));
90 if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
91 mlir::utils::IteratorType::parallel)
98 auto vnniDimSize = typeB.getShape().back();
99 if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
100 vnniDimSize % 2 != 0)
102 if (typeA.getShape().back() != vnniDimSize)
104 if (blockingFactor && vnniDimSize != *blockingFactor)
108 if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
121 assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
122 "Unsupported nonUnitDimAcc value");
125 static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
126 static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
129 static constexpr int64_t maskLo8_avx2_int8[] = {0, 1, 2, 3, 8, 9, 10, 11};
130 static constexpr int64_t maskHi8_avx2_int8[] = {4, 5, 6, 7, 12, 13, 14, 15};
133 static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
134 4, 5, 6, 7, 20, 21, 22, 23};
135 static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
136 12, 13, 14, 15, 28, 29, 30, 31};
138 if (nonUnitDimAcc == 16)
139 return {maskLo16, maskHi16};
142 return {maskLo8_avx2_int8, maskHi8_avx2_int8};
144 return {maskLo8, maskHi8};
158 if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
165 if (
auto barg = dyn_cast<BlockArgument>(v)) {
166 auto *parentOp = barg.getOwner()->getParentOp();
168 if (
auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
169 unsigned argNum = barg.getArgNumber();
175 unsigned iterIdx = argNum - 1;
176 v = forOp.getInitArgs()[iterIdx];
203 if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
206 if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
210 if (
auto yield = dyn_cast<scf::YieldOp>(user)) {
212 unsigned idx = use.getOperandNumber();
220 if (
auto forOp = dyn_cast<scf::ForOp>(user)) {
221 unsigned idx = use.getOperandNumber();
244 vector::ContractionOp contractA,
245 vector::ContractionOp contractB,
246 int64_t nonUnitDimAcc, VectorType accTy) {
248 if (!isa<vector::TransferReadOp, vector::LoadOp>(opA) ||
249 !isa<vector::TransferReadOp, vector::LoadOp>(opB)) {
258 auto elemTy = accTy.getElementType();
259 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
262 vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->
getResult(0));
264 vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->
getResult(0));
267 nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
269 auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
270 castB, masks.maskLo);
271 auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
272 castB, masks.maskHi);
274 auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
275 auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
279 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
284 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
298 if (
auto write = dyn_cast<vector::TransferWriteOp>(op))
299 return write.getVector();
300 if (
auto store = dyn_cast<vector::StoreOp>(op))
301 return store.getValueToStore();
305 Value vecA = getWrittenVector(opA);
306 Value vecB = getWrittenVector(opB);
317 auto elemTy = accTy.getElementType();
318 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
321 auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
322 auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
326 nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
328 auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
329 castB, masks.maskLo);
330 auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
331 castB, masks.maskHi);
334 auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
335 auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
339 [&]() { opA->
setOperand(0, newVecA.getResult()); });
341 [&]() { opB->
setOperand(0, newVecB.getResult()); });
353 vector::ContractionOp pairContOp,
354 bool rhsHasMultipleNonUnitDims,
356 if (contractOp == pairContOp)
359 if (rhsHasMultipleNonUnitDims &&
360 !(contractOp.getLhs() == pairContOp.getLhs()))
363 if (!rhsHasMultipleNonUnitDims &&
364 !(contractOp.getRhs() == pairContOp.getRhs()))
367 auto nonUnitOperand =
368 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
369 auto nonUnitOperandPairContOp =
370 rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
375 .Case<vector::TransferReadOp, vector::LoadOp>([&](
auto readOp) {
376 srcBuff = readOp.getOperand(0);
378 readOp.getIndices().end());
380 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
381 srcBuff = op.getSource();
385 Value srcBuffPairContOp;
388 .Case<vector::TransferReadOp, vector::LoadOp>([&](
auto readOp) {
389 srcBuffPairContOp = readOp.getOperand(0);
391 readOp.getIndices().begin(), readOp.getIndices().end());
393 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
394 srcBuffPairContOp = op.getSource();
398 if (!srcBuff || !srcBuffPairContOp)
402 auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();
404 if (shuffleLw && shuffleHw)
405 return shuffleLw.getV1() == shuffleHw.getV1() &&
406 shuffleLw.getV2() == shuffleHw.getV2();
408 if (srcBuff != srcBuffPairContOp)
411 bool oneConstantOffset =
false;
412 for (
size_t i = 0; i < indexVals.size(); i++) {
414 if (indexVals[i] == indexValsPairContOp[i])
423 if ((*v1 - *v0) != nonUnitDimValue)
426 oneConstantOffset =
true;
429 return oneConstantOffset;
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
AffineExpr getResult(unsigned idx) const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
void setOperand(unsigned idx, Value value)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumUses() const
This method computes the number of uses of this Value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within `linalgOp`.
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Operation * traceToVectorWriteLikeUserOperation(Value v)
static FailureOr< SmallVector< mlir::utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::ArrayRef< int64_t > maskHi
llvm::ArrayRef< int64_t > maskLo