33static bool isNonUnitDimOperandShuffled(
Value nonUnitDimOperand) {
35 if (isa<vector::ShuffleOp>(defOp))
38 if (isa<vector::ShapeCastOp>(defOp)) {
40 if (isa<vector::ShuffleOp>(defOpShpCst))
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
68 mlir::VectorType Ty) {
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
90 srcBuff = readOp.getOperand(0);
92 auto indices = readOp.getIndices();
93 indexVals.reserve(
indices.size());
101 int64_t srcRank = (dyn_cast<ShapedType>(srcBuff.
getType())).getRank();
102 Value padding = ub::PoisonOp::create(rewriter, loc, elemTy);
107 auto vec1 = vector::TransferReadOp::create(rewriter, loc, flatTy, srcBuff,
108 indexVals, padding, map, inBounds);
110 unsigned int offset = 1;
111 if (elemTy.isSignlessInteger(8))
116 arith::AddIOp::create(rewriter, loc, rewriter.
getIndexType(), cOffset,
117 indexVals[indexVals.size() - 2]);
118 indexVals[indexVals.size() - 2] = nextIndx;
120 auto vec2 = vector::TransferReadOp::create(rewriter, loc, flatTy, srcBuff,
121 indexVals, padding, map, inBounds);
123 static constexpr int64_t maskLo_bf16[] = {
124 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
125 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
126 static constexpr int64_t maskHi_bf16[] = {
127 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
128 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
130 static constexpr int64_t maskLo_int8_avx2[] = {
131 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
132 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
133 static constexpr int64_t maskHi_int8_avx2[] = {
134 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
135 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
137 static constexpr int64_t maskLo_int8_avx10[] = {
138 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99,
139 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
140 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
141 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
142 static constexpr int64_t maskHi_int8_avx10[] = {
143 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103,
144 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
145 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
146 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
151 if (elemTy.isSignlessInteger(8)) {
155 if (nonUnitDimAcc == 32) {
161 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
163 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
166 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
167 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
169 rewriteUses(opA->
getResult(0), newA.getResult(), contractA, rewriter);
170 rewriteUses(opB->
getResult(0), newB.getResult(), contractB, rewriter);
204struct VectorContractToPackedTypeDotProduct
206 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
208 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
209 PatternRewriter &rewriter)
const override {
211 if (contractOp.getKind() != vector::CombiningKind::ADD)
213 "Expects add combining kind.");
215 VectorType lhsTy = contractOp.getLhsType();
216 if (!lhsTy.getElementType().isBF16() &&
217 !lhsTy.getElementType().isSignlessInteger(8))
219 contractOp,
"Only BF16/Int8 lowering is supported.");
221 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
224 contractOp.getIndexingMapsArray(), blockingFactor);
226 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
230 ArrayRef<int64_t> accShape = accTy.getShape();
231 llvm::SmallVector<int64_t> nonUnitDimAcc;
232 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
233 [](int64_t dim) {
return dim != 1; });
234 if (nonUnitDimAcc.size() != 1)
236 contractOp,
"A or B should be a non-unit dim in acc.");
238 int64_t nonUnitDimValue = nonUnitDimAcc.front();
241 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
242 nonUnitDimValue != 8 && nonUnitDimValue != 16)
244 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
245 "RHS) dim and acc dim of size 4/8/16.");
247 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
248 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
249 nonUnitDimAcc.front() == nonUnitDimValue)
251 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
252 "RHS) dim and acc dim of size 4/8/16.");
254 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
255 llvm::SmallVector<int64_t> nonUnitDimLhs;
256 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
257 [](int64_t dim) {
return dim != 1; });
259 VectorType rhsTy = contractOp.getRhsType();
260 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
261 llvm::SmallVector<int64_t> nonUnitDimRhs;
262 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
263 [](int64_t dim) {
return dim != 1; });
265 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
267 "Excepts unit dimensions for either "
268 "LHS or RHS shape.");
270 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
273 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
275 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
276 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
277 : nonUnitDimRhs.front();
279 if (!isVnni && (extraFlatDim != blockingFactor))
281 contractOp,
"The K or reduction dim for flat layout should be 2/4.");
283 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
284 (lhsTy.getElementType().isSignlessInteger(8) &&
285 !accTy.getElementType().isSignlessInteger(32)))
287 "Only F32 for BF16 or Int32 for Int8 "
288 "accumulation type is supported.");
290 Value unitDimOperand =
291 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
292 Value nonUnitDimOperand =
293 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
298 vector::ContractionOp pairContractOp;
299 Operation *nextOp = contractOp;
300 while ((nextOp = nextOp->getNextNode())) {
301 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
307 rhsHasMultipleNonUnitDims,
309 pairContractOp = contOp;
319 if (!pairContractOp &&
320 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
322 "Could not find a contract pair");
328 Operation *accReadOp0 =
330 Operation *accReadOp1 =
335 Operation *resultWriteOp0 =
337 Operation *resultWriteOp1 =
340 if (!accReadOp0 || !accReadOp1)
343 "Operands doesn't have load or transfer_read as it's parent op");
345 if (!resultWriteOp0 || !resultWriteOp1)
348 "The use of contract operations are neither vector.store "
349 "or transfer_write or has multiple users.");
351 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
352 contractOp->isBeforeInBlock(accReadOp1))
355 "The load/read operation of pair contract operation is "
356 "after the contractOp");
358 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
361 contractOp,
"The store/write operation of contract operation is "
362 "before the pair contract operation");
364 LogicalResult readShuffle =
366 pairContractOp, nonUnitDimValue, accTy);
370 contractOp,
"Accumulator read is not by transfer_read or load");
374 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
379 "Write to accumulator is not by transfer_write or store");
382 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
383 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
384 ? pairContractOp.getRhs()
385 : pairContractOp.getLhs();
388 Operation *nonUnitDimReadOp =
390 Operation *nonUnitDimReadOpPairContract =
393 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
395 contractOp,
"Could not find a valid contract pair");
397 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
398 ? contractOp.getRhsType()
399 : contractOp.getLhsType();
401 packNonUnitDimOperandToVNNI(
402 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
403 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
406 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
407 : contractOp.getLhs();
412 auto loc = contractOp.getLoc();
413 auto castAcc = vector::ShapeCastOp::create(
415 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
416 contractOp.getAcc());
418 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
419 ? contractOp.getRhsType()
420 : contractOp.getLhsType();
421 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
422 : contractOp.getRhsType();
426 auto castNonUnitDim = vector::ShapeCastOp::create(
428 VectorType::get(blockingFactor * nonUnitDimValue,
429 nonUnitDimTy.getElementType()),
432 auto castUnitDim = vector::ShapeCastOp::create(
434 VectorType::get(blockingFactor, unitDimTy.getElementType()),
436 auto bitcastUnitDim = vector::BitCastOp::create(
439 auto broadcastUnitDim = vector::BroadcastOp::create(
443 auto bitcastUnitDimPkType = vector::BitCastOp::create(
444 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
446 if (lhsTy.getElementType().isBF16()) {
447 dp = x86::avx512::DotBF16Op::create(
449 VectorType::get(nonUnitDimValue, rewriter.
getF32Type()), castAcc,
450 bitcastUnitDimPkType, castNonUnitDim);
453 if (lhsTy.getElementType().isSignlessInteger(8)) {
454 if (nonUnitDimAcc.front() == 16) {
455 dp = x86::avx10::AVX10DotInt8Op::create(
458 castAcc, bitcastUnitDimPkType, castNonUnitDim);
460 dp = x86::avx::DotInt8Op::create(
463 castAcc, bitcastUnitDimPkType, castNonUnitDim);
470 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
480 patterns.
add<VectorContractToPackedTypeDotProduct>(patterns.
getContext());
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Operation * traceToVectorWriteLikeUserOperation(Value v)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...