29static FailureOr<SmallVector<mlir::utils::IteratorType>>
34 map.
getNumDims(), mlir::utils::IteratorType::reduction);
36 if (
auto dim = dyn_cast<AffineDimExpr>(expr))
37 iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
44 std::optional<unsigned> blockingFactor) {
46 FailureOr<linalg::ContractionDimensions> dims =
53 auto typeA = dyn_cast<ShapedType>(matA.getType());
54 auto typeB = dyn_cast<ShapedType>(matB.getType());
55 unsigned rankA = typeA.getRank();
56 unsigned rankB = typeB.getRank();
58 if (rankA < 3 || rankB < 3)
63 if (dims->k.size() < 2)
78 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 1));
79 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 1));
80 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
81 iteratorTypes[vnniDimA.getPosition()] !=
82 mlir::utils::IteratorType::reduction)
84 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(rankA - 2));
85 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 3));
86 if (!redDimA || !redDimB || redDimA != redDimB ||
87 iteratorTypes[redDimA.getPosition()] !=
88 mlir::utils::IteratorType::reduction)
90 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(rankB - 2));
91 if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
92 mlir::utils::IteratorType::parallel)
99 auto vnniDimSize = typeB.getShape().back();
100 if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
101 vnniDimSize % 2 != 0)
103 if (typeA.getShape().back() != vnniDimSize)
105 if (blockingFactor && vnniDimSize != *blockingFactor)
109 if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
127struct VectorContractToPackedTypeDotProduct
129 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
131 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
132 PatternRewriter &rewriter)
const override {
134 if (contractOp.getKind() != vector::CombiningKind::ADD)
136 "Expects add combining kind.");
138 VectorType lhsTy = contractOp.getLhsType();
139 if (!lhsTy.getElementType().isBF16() &&
140 !lhsTy.getElementType().isSignlessInteger(8))
142 contractOp,
"Only BF16/Int8 lowering is supported.");
144 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
145 if (!isInVnniLayout(contractOp.getOperation(),
146 contractOp.getIndexingMapsArray(), blockingFactor))
148 "Input matrices not in VNNI format.");
150 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
151 llvm::SmallVector<int64_t> nonUnitDimLhs;
152 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
153 [](int64_t dim) {
return dim != 1; });
155 VectorType rhsTy = contractOp.getRhsType();
156 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
157 llvm::SmallVector<int64_t> nonUnitDimRhs;
158 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
159 [](int64_t dim) {
return dim != 1; });
161 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
163 "Excepts unit dimensions for either "
164 "LHS or RHS shape other than VNNI.");
166 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
169 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
171 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
175 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
176 (lhsTy.getElementType().isSignlessInteger(8) &&
177 !accTy.getElementType().isSignlessInteger(32)))
179 "Only F32 for BF16 or Int32 for Int8 "
180 "accumulation type is supported.");
182 ArrayRef<int64_t> accShape = accTy.getShape();
183 llvm::SmallVector<int64_t> nonUnitDimAcc;
184 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
185 [](int64_t dim) {
return dim != 1; });
186 if (nonUnitDimAcc.size() != 1)
188 contractOp,
"A or B should be a non-unit dim in acc.");
192 unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
193 : nonUnitDimRhs.front();
194 if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
195 nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
197 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
198 "RHS) dim and acc dim of size 4/8/16.");
200 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
201 nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
203 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
204 "RHS) dim and acc dim of size 4/8.");
206 auto loc = contractOp.getLoc();
207 auto castAcc = vector::ShapeCastOp::create(
209 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
210 contractOp.getAcc());
219 if ((nonUnitDimRhs.size() - 1) > 0) {
220 auto castRhs = vector::ShapeCastOp::create(
222 VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
223 rhsTy.getElementType()),
224 contractOp.getRhs());
225 auto castLhs = vector::ShapeCastOp::create(
227 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
228 contractOp.getLhs());
229 auto bitcastLhs = vector::BitCastOp::create(
232 auto broadcastLhs = vector::BroadcastOp::create(
234 VectorType::get({nonUnitDimRhs.front()}, rewriter.
getIntegerType(32)),
236 auto bitcastLhsPkType = vector::BitCastOp::create(
237 rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
239 if (lhsTy.getElementType().isBF16()) {
240 dp = x86vector::DotBF16Op::create(
242 VectorType::get(nonUnitDimRhs.front(), rewriter.
getF32Type()),
243 castAcc, bitcastLhsPkType, castRhs);
246 if (lhsTy.getElementType().isSignlessInteger(8)) {
247 dp = x86vector::DotInt8Op::create(
249 VectorType::get(nonUnitDimRhs.front(), rewriter.
getIntegerType(32)),
250 castAcc, bitcastLhsPkType, castRhs);
253 auto castLhs = vector::ShapeCastOp::create(
255 VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
256 lhsTy.getElementType()),
257 contractOp.getLhs());
258 auto castRhs = vector::ShapeCastOp::create(
260 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
261 contractOp.getRhs());
262 auto bitcastRhs = vector::BitCastOp::create(
265 auto broadcastRhs = vector::BroadcastOp::create(
267 VectorType::get({nonUnitDimLhs.front()}, rewriter.
getIntegerType(32)),
269 auto bitcastRhsPkType = vector::BitCastOp::create(
270 rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
272 if (lhsTy.getElementType().isBF16()) {
273 dp = x86vector::DotBF16Op::create(
275 VectorType::get(nonUnitDimLhs.front(), rewriter.
getF32Type()),
276 castAcc, castLhs, bitcastRhsPkType);
279 if (lhsTy.getElementType().isSignlessInteger(8)) {
280 dp = x86vector::DotInt8Op::create(
282 VectorType::get(nonUnitDimLhs.front(), rewriter.
getIntegerType(32)),
283 castAcc, castLhs, bitcastRhsPkType);
290 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e., a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
AffineExpr getResult(unsigned idx) const
IntegerType getIntegerType(unsigned width)
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation within linalgOp.
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.