33static bool isNonUnitDimOperandShuffled(
Value nonUnitDimOperand) {
35 if (isa<vector::ShuffleOp>(defOp))
38 if (isa<vector::ShapeCastOp>(defOp)) {
40 if (isa<vector::ShuffleOp>(defOpShpCst))
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
68 mlir::VectorType Ty) {
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
85 if (elemTy.isSignlessInteger(8))
86 flatTy = mlir::VectorType::get({2, nonUnitDimAcc / 2}, elemTy);
93 srcBuff = readOp.getOperand(0);
95 auto indices = readOp.getIndices();
96 indexVals.reserve(
indices.size());
104 int64_t srcRank = (dyn_cast<ShapedType>(srcBuff.getType())).getRank();
105 Value padding = ub::PoisonOp::create(rewriter, loc, elemTy);
110 Value vec1 = vector::TransferReadOp::create(
111 rewriter, loc, flatTy, srcBuff, indexVals, padding, map, inBounds);
113 if (elemTy.isSignlessInteger(8))
114 vec1 = vector::ShapeCastOp::create(
115 rewriter, loc, VectorType::get(nonUnitDimAcc, elemTy), vec1);
117 unsigned int offset = 1;
118 if (elemTy.isSignlessInteger(8))
123 arith::AddIOp::create(rewriter, loc, rewriter.
getIndexType(), cOffset,
124 indexVals[indexVals.size() - 2]);
125 indexVals[indexVals.size() - 2] = nextIndx;
127 Value vec2 = vector::TransferReadOp::create(
128 rewriter, loc, flatTy, srcBuff, indexVals, padding, map, inBounds);
130 if (elemTy.isSignlessInteger(8))
131 vec2 = vector::ShapeCastOp::create(
132 rewriter, loc, VectorType::get(nonUnitDimAcc, elemTy), vec2);
134 flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
136 static constexpr int64_t maskLo_bf16[] = {
137 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
138 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
139 static constexpr int64_t maskHi_bf16[] = {
140 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
141 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
143 static constexpr int64_t maskLo_int8_avx2[] = {
144 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
145 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
146 static constexpr int64_t maskHi_int8_avx2[] = {
147 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
148 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
150 static constexpr int64_t maskLo_int8_avx10[] = {
151 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99,
152 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
153 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
154 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
155 static constexpr int64_t maskHi_int8_avx10[] = {
156 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103,
157 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
158 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
159 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
164 if (elemTy.isSignlessInteger(8)) {
168 if (nonUnitDimAcc == 32) {
174 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
176 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
179 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
180 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
182 rewriteUses(opA->
getResult(0), newA.getResult(), contractA, rewriter);
183 rewriteUses(opB->
getResult(0), newB.getResult(), contractB, rewriter);
217struct VectorContractToPackedTypeDotProduct
219 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
221 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
222 PatternRewriter &rewriter)
const override {
224 if (contractOp.getKind() != vector::CombiningKind::ADD)
226 "Expects add combining kind.");
228 VectorType lhsTy = contractOp.getLhsType();
229 if (!lhsTy.getElementType().isBF16() &&
230 !lhsTy.getElementType().isSignlessInteger(8))
232 contractOp,
"Only BF16/Int8 lowering is supported.");
234 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
237 contractOp.getIndexingMapsArray(), blockingFactor);
239 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
243 ArrayRef<int64_t> accShape = accTy.getShape();
244 llvm::SmallVector<int64_t> nonUnitDimAcc;
245 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
246 [](int64_t dim) {
return dim != 1; });
247 if (nonUnitDimAcc.size() != 1)
249 contractOp,
"A or B should be a non-unit dim in acc.");
251 int64_t nonUnitDimValue = nonUnitDimAcc.front();
254 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
255 nonUnitDimValue != 8 && nonUnitDimValue != 16)
257 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
258 "RHS) dim and acc dim of size 4/8/16.");
260 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
261 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
262 nonUnitDimAcc.front() == nonUnitDimValue)
264 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
265 "RHS) dim and acc dim of size 4/8/16.");
267 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
268 llvm::SmallVector<int64_t> nonUnitDimLhs;
269 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
270 [](int64_t dim) {
return dim != 1; });
272 VectorType rhsTy = contractOp.getRhsType();
273 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
274 llvm::SmallVector<int64_t> nonUnitDimRhs;
275 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
276 [](int64_t dim) {
return dim != 1; });
278 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
280 "Excepts unit dimensions for either "
281 "LHS or RHS shape.");
283 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
286 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
288 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
289 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
290 : nonUnitDimRhs.front();
292 if (!isVnni && (extraFlatDim != blockingFactor))
294 contractOp,
"The K or reduction dim for flat layout should be 2/4.");
296 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
297 (lhsTy.getElementType().isSignlessInteger(8) &&
298 !accTy.getElementType().isSignlessInteger(32)))
300 "Only F32 for BF16 or Int32 for Int8 "
301 "accumulation type is supported.");
303 Value unitDimOperand =
304 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
305 Value nonUnitDimOperand =
306 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
311 vector::ContractionOp pairContractOp;
312 Operation *nextOp = contractOp;
313 while ((nextOp = nextOp->getNextNode())) {
314 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
320 rhsHasMultipleNonUnitDims,
322 pairContractOp = contOp;
332 if (!pairContractOp &&
333 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
335 "Could not find a contract pair");
341 Operation *accReadOp0 =
343 Operation *accReadOp1 =
348 Operation *resultWriteOp0 =
350 Operation *resultWriteOp1 =
353 if (!accReadOp0 || !accReadOp1)
356 "Operands doesn't have load or transfer_read as it's parent op");
358 if (!resultWriteOp0 || !resultWriteOp1)
361 "The use of contract operations are neither vector.store "
362 "or transfer_write or has multiple users.");
364 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
365 contractOp->isBeforeInBlock(accReadOp1))
368 "The load/read operation of pair contract operation is "
369 "after the contractOp");
371 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
374 contractOp,
"The store/write operation of contract operation is "
375 "before the pair contract operation");
377 LogicalResult readShuffle =
379 pairContractOp, nonUnitDimValue, accTy);
383 contractOp,
"Accumulator read is not by transfer_read or load");
387 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
392 "Write to accumulator is not by transfer_write or store");
395 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
396 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
397 ? pairContractOp.getRhs()
398 : pairContractOp.getLhs();
401 Operation *nonUnitDimReadOp =
403 Operation *nonUnitDimReadOpPairContract =
406 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
408 contractOp,
"Could not find a valid contract pair");
410 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
411 ? contractOp.getRhsType()
412 : contractOp.getLhsType();
414 packNonUnitDimOperandToVNNI(
415 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
416 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
419 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
420 : contractOp.getLhs();
425 auto loc = contractOp.getLoc();
426 auto castAcc = vector::ShapeCastOp::create(
428 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
429 contractOp.getAcc());
431 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
432 ? contractOp.getRhsType()
433 : contractOp.getLhsType();
434 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
435 : contractOp.getRhsType();
439 auto castNonUnitDim = vector::ShapeCastOp::create(
441 VectorType::get(blockingFactor * nonUnitDimValue,
442 nonUnitDimTy.getElementType()),
445 auto castUnitDim = vector::ShapeCastOp::create(
447 VectorType::get(blockingFactor, unitDimTy.getElementType()),
449 auto bitcastUnitDim = vector::BitCastOp::create(
452 auto broadcastUnitDim = vector::BroadcastOp::create(
456 auto bitcastUnitDimPkType = vector::BitCastOp::create(
457 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
459 if (lhsTy.getElementType().isBF16()) {
460 dp = x86::avx512::DotBF16Op::create(
462 VectorType::get(nonUnitDimValue, rewriter.
getF32Type()), castAcc,
463 bitcastUnitDimPkType, castNonUnitDim);
466 if (lhsTy.getElementType().isSignlessInteger(8)) {
467 if (nonUnitDimAcc.front() == 16) {
468 dp = x86::avx10::AVX10DotInt8Op::create(
471 castAcc, bitcastUnitDimPkType, castNonUnitDim);
473 dp = x86::avx::DotInt8Op::create(
476 castAcc, bitcastUnitDimPkType, castNonUnitDim);
483 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
493 patterns.
add<VectorContractToPackedTypeDotProduct>(patterns.
getContext());
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp, which may be in the same or another block in the same function.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Operation * traceToVectorWriteLikeUserOperation(Value v)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.