33static bool isNonUnitDimOperandShuffled(
Value nonUnitDimOperand) {
35 if (isa<vector::ShuffleOp>(defOp))
38 if (isa<vector::ShapeCastOp>(defOp)) {
40 if (isa<vector::ShuffleOp>(defOpShpCst))
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
68 mlir::VectorType Ty) {
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
85 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
87 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
90 static constexpr int64_t maskLo[] = {
91 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
92 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
93 static constexpr int64_t maskHi[] = {
94 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
95 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
97 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
99 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
102 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
103 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
105 rewriteUses(opA->
getResult(0), newA.getResult(), contractA, rewriter);
106 rewriteUses(opB->
getResult(0), newB.getResult(), contractB, rewriter);
140struct VectorContractToPackedTypeDotProduct
142 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
144 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
145 PatternRewriter &rewriter)
const override {
147 if (contractOp.getKind() != vector::CombiningKind::ADD)
149 "Expects add combining kind.");
151 VectorType lhsTy = contractOp.getLhsType();
152 if (!lhsTy.getElementType().isBF16() &&
153 !lhsTy.getElementType().isSignlessInteger(8))
155 contractOp,
"Only BF16/Int8 lowering is supported.");
157 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
160 contractOp.getIndexingMapsArray(), blockingFactor);
162 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
165 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
169 ArrayRef<int64_t> accShape = accTy.getShape();
170 llvm::SmallVector<int64_t> nonUnitDimAcc;
171 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
172 [](int64_t dim) {
return dim != 1; });
173 if (nonUnitDimAcc.size() != 1)
175 contractOp,
"A or B should be a non-unit dim in acc.");
177 int64_t nonUnitDimValue = nonUnitDimAcc.front();
180 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
181 nonUnitDimValue != 8 && nonUnitDimValue != 16)
183 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
184 "RHS) dim and acc dim of size 4/8/16.");
186 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
187 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
188 nonUnitDimAcc.front() == nonUnitDimValue)
190 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
191 "RHS) dim and acc dim of size 4/8/16.");
193 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
194 llvm::SmallVector<int64_t> nonUnitDimLhs;
195 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
196 [](int64_t dim) {
return dim != 1; });
198 VectorType rhsTy = contractOp.getRhsType();
199 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
200 llvm::SmallVector<int64_t> nonUnitDimRhs;
201 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
202 [](int64_t dim) {
return dim != 1; });
204 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
206 "Excepts unit dimensions for either "
207 "LHS or RHS shape.");
209 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
212 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
214 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
215 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
216 : nonUnitDimRhs.front();
218 if (!isVnni && (extraFlatDim != blockingFactor))
220 contractOp,
"The K or reduction dim for flat layout should be 2.");
222 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
223 (lhsTy.getElementType().isSignlessInteger(8) &&
224 !accTy.getElementType().isSignlessInteger(32)))
226 "Only F32 for BF16 or Int32 for Int8 "
227 "accumulation type is supported.");
229 Value unitDimOperand =
230 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
231 Value nonUnitDimOperand =
232 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
237 vector::ContractionOp pairContractOp;
238 Operation *nextOp = contractOp;
239 while ((nextOp = nextOp->getNextNode())) {
240 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
246 rhsHasMultipleNonUnitDims,
248 pairContractOp = contOp;
258 if (!pairContractOp &&
259 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
261 "Could not find a contract pair");
267 Operation *accReadOp0 =
269 Operation *accReadOp1 =
274 Operation *resultWriteOp0 =
276 Operation *resultWriteOp1 =
279 if (!accReadOp0 || !accReadOp1)
282 "Operands doesn't have load or transfer_read as it's parent op");
284 if (!resultWriteOp0 || !resultWriteOp1)
287 "The use of contract operations are neither vector.store "
288 "or transfer_write or has multiple users.");
290 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
291 contractOp->isBeforeInBlock(accReadOp1))
294 "The load/read operation of pair contract operation is "
295 "after the contractOp");
297 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
300 contractOp,
"The store/write operation of contract operation is "
301 "before the pair contract operation");
303 LogicalResult readShuffle =
305 pairContractOp, nonUnitDimValue, accTy);
309 contractOp,
"Accumulator read is not by transfer_read or load");
313 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
318 "Write to accumulator is not by transfer_write or store");
321 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
322 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
323 ? pairContractOp.getRhs()
324 : pairContractOp.getLhs();
327 Operation *nonUnitDimReadOp =
329 Operation *nonUnitDimReadOpPairContract =
332 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
334 contractOp,
"Could not find a valid contract pair");
336 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
337 ? contractOp.getRhsType()
338 : contractOp.getLhsType();
340 packNonUnitDimOperandToVNNI(
341 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
342 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
345 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
346 : contractOp.getLhs();
351 auto loc = contractOp.getLoc();
352 auto castAcc = vector::ShapeCastOp::create(
354 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
355 contractOp.getAcc());
357 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
358 ? contractOp.getRhsType()
359 : contractOp.getLhsType();
360 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
361 : contractOp.getRhsType();
365 auto castNonUnitDim = vector::ShapeCastOp::create(
367 VectorType::get(blockingFactor * nonUnitDimValue,
368 nonUnitDimTy.getElementType()),
371 auto castUnitDim = vector::ShapeCastOp::create(
373 VectorType::get(blockingFactor, unitDimTy.getElementType()),
375 auto bitcastUnitDim = vector::BitCastOp::create(
378 auto broadcastUnitDim = vector::BroadcastOp::create(
382 auto bitcastUnitDimPkType = vector::BitCastOp::create(
383 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
385 if (lhsTy.getElementType().isBF16()) {
386 dp = x86::avx512::DotBF16Op::create(
388 VectorType::get(nonUnitDimValue, rewriter.
getF32Type()), castAcc,
389 bitcastUnitDimPkType, castNonUnitDim);
392 if (lhsTy.getElementType().isSignlessInteger(8)) {
393 if (nonUnitDimAcc.front() == 16) {
394 dp = x86::avx10::AVX10DotInt8Op::create(
397 castAcc, bitcastUnitDimPkType, castNonUnitDim);
399 dp = x86::avx::DotInt8Op::create(
402 castAcc, bitcastUnitDimPkType, castNonUnitDim);
409 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
419 patterns.
add<VectorContractToPackedTypeDotProduct>(patterns.
getContext());
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Operation * traceToVectorWriteLikeUserOperation(Value v)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...