33static bool isNonUnitDimOperandShuffled(
Value nonUnitDimOperand) {
35 if (isa<vector::ShuffleOp>(defOp))
38 if (isa<vector::ShapeCastOp>(defOp)) {
40 if (isa<vector::ShuffleOp>(defOpShpCst))
54 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
55 mlir::isa<mlir::scf::ForOp>(user)) {
66 mlir::vector::ContractionOp contractA,
67 mlir::vector::ContractionOp contractB,
69 mlir::VectorType Ty) {
75 auto elemTy = Ty.getElementType();
76 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
78 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
80 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
83 static constexpr int64_t maskLo[] = {
84 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
85 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
86 static constexpr int64_t maskHi[] = {
87 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
88 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
90 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
92 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
95 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
96 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
98 rewriteUses(opA->
getResult(0), newA.getResult(), contractA, rewriter);
99 rewriteUses(opB->
getResult(0), newB.getResult(), contractB, rewriter);
133struct VectorContractToPackedTypeDotProduct
135 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
137 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
138 PatternRewriter &rewriter)
const override {
140 if (contractOp.getKind() != vector::CombiningKind::ADD)
142 "Expects add combining kind.");
144 VectorType lhsTy = contractOp.getLhsType();
145 if (!lhsTy.getElementType().isBF16() &&
146 !lhsTy.getElementType().isSignlessInteger(8))
148 contractOp,
"Only BF16/Int8 lowering is supported.");
150 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
153 contractOp.getIndexingMapsArray(), blockingFactor);
155 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
158 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
162 ArrayRef<int64_t> accShape = accTy.getShape();
163 llvm::SmallVector<int64_t> nonUnitDimAcc;
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](int64_t dim) {
return dim != 1; });
166 if (nonUnitDimAcc.size() != 1)
168 contractOp,
"A or B should be a non-unit dim in acc.");
170 int64_t nonUnitDimValue = nonUnitDimAcc.front();
173 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
174 nonUnitDimValue != 8 && nonUnitDimValue != 16)
176 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
177 "RHS) dim and acc dim of size 4/8/16.");
179 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
180 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
181 nonUnitDimAcc.front() == nonUnitDimValue)
183 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
184 "RHS) dim and acc dim of size 4/8/16.");
186 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
187 llvm::SmallVector<int64_t> nonUnitDimLhs;
188 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
189 [](int64_t dim) {
return dim != 1; });
191 VectorType rhsTy = contractOp.getRhsType();
192 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
193 llvm::SmallVector<int64_t> nonUnitDimRhs;
194 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
195 [](int64_t dim) {
return dim != 1; });
197 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
199 "Excepts unit dimensions for either "
200 "LHS or RHS shape.");
202 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
205 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
207 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
208 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
209 : nonUnitDimRhs.front();
211 if (!isVnni && (extraFlatDim != blockingFactor))
213 contractOp,
"The K or reduction dim for flat layout should be 2.");
215 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
216 (lhsTy.getElementType().isSignlessInteger(8) &&
217 !accTy.getElementType().isSignlessInteger(32)))
219 "Only F32 for BF16 or Int32 for Int8 "
220 "accumulation type is supported.");
222 Value unitDimOperand =
223 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
224 Value nonUnitDimOperand =
225 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
230 vector::ContractionOp pairContractOp;
231 Operation *nextOp = contractOp;
232 while ((nextOp = nextOp->getNextNode())) {
233 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
239 rhsHasMultipleNonUnitDims,
241 pairContractOp = contOp;
251 if (!pairContractOp &&
252 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
254 "Could not find a contract pair");
260 Operation *accReadOp0 =
262 Operation *accReadOp1 =
267 Operation *resultWriteOp0 =
269 Operation *resultWriteOp1 =
272 if (!accReadOp0 || !accReadOp1)
275 "Operands doesn't have load or transfer_read as it's parent op");
277 if (!resultWriteOp0 || !resultWriteOp1)
280 "The use of contract operations are neither vector.store "
281 "or transfer_write or has multiple users.");
283 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
284 contractOp->isBeforeInBlock(accReadOp1))
287 "The load/read operation of pair contract operation is "
288 "after the contractOp");
290 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
293 contractOp,
"The store/write operation of contract operation is "
294 "before the pair contract operation");
296 LogicalResult readShuffle =
298 pairContractOp, nonUnitDimValue, accTy);
302 contractOp,
"Accumulator read is not by transfer_read or load");
306 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
311 "Write to accumulator is not by transfer_write or store");
314 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
315 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
316 ? pairContractOp.getRhs()
317 : pairContractOp.getLhs();
320 Operation *nonUnitDimReadOp =
322 Operation *nonUnitDimReadOpPairContract =
325 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
327 contractOp,
"Could not find a valid contract pair");
329 if (contractOp->getBlock() ==
330 nonUnitDimReadOpPairContract->
getBlock() &&
331 contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
334 "The load/read operation of pair contract operation is "
335 "after the contractOp");
337 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
338 ? contractOp.getRhsType()
339 : contractOp.getLhsType();
341 packNonUnitDimOperandToVNNI(
342 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
343 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
346 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
347 : contractOp.getLhs();
352 auto loc = contractOp.getLoc();
353 auto castAcc = vector::ShapeCastOp::create(
355 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
356 contractOp.getAcc());
358 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
359 ? contractOp.getRhsType()
360 : contractOp.getLhsType();
361 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
362 : contractOp.getRhsType();
366 auto castNonUnitDim = vector::ShapeCastOp::create(
368 VectorType::get(blockingFactor * nonUnitDimValue,
369 nonUnitDimTy.getElementType()),
372 auto castUnitDim = vector::ShapeCastOp::create(
374 VectorType::get(blockingFactor, unitDimTy.getElementType()),
376 auto bitcastUnitDim = vector::BitCastOp::create(
379 auto broadcastUnitDim = vector::BroadcastOp::create(
383 auto bitcastUnitDimPkType = vector::BitCastOp::create(
384 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
386 if (lhsTy.getElementType().isBF16()) {
387 dp = x86vector::DotBF16Op::create(
389 VectorType::get(nonUnitDimValue, rewriter.
getF32Type()), castAcc,
390 bitcastUnitDimPkType, castNonUnitDim);
393 if (lhsTy.getElementType().isSignlessInteger(8)) {
394 if (nonUnitDimAcc.front() == 16) {
395 dp = x86vector::AVX10DotInt8Op::create(
398 castAcc, bitcastUnitDimPkType, castNonUnitDim);
400 dp = x86vector::DotInt8Op::create(
403 castAcc, bitcastUnitDimPkType, castNonUnitDim);
410 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * traceToVectorReadLikeParentOperation(Value v)
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...