193 if (contractOp.getKind() != vector::CombiningKind::ADD)
195 "Expects add combining kind.");
200 VectorType lhsTy = contractOp.getLhsType();
201 if (!lhsTy.getElementType().isBF16())
203 "Only BF16 lowering is supported.");
206 contractOp.getIndexingMapsArray(),
209 "Input matrices not in VNNI format.");
211 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
215 if (!accTy.getElementType().isF32())
217 contractOp,
"Only F32 acumulation supported for BF16 type.");
221 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
222 [](
int64_t dim) {
return dim != 1; });
224 VectorType rhsTy = contractOp.getRhsType();
227 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
228 [](
int64_t dim) {
return dim != 1; });
230 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
232 "Excepts unit dimensions for either "
233 "LHS or RHS shape other than VNNI.");
235 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
238 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
242 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
243 [](
int64_t dim) {
return dim != 1; });
244 if (nonUnitDimAcc.size() != 1)
246 contractOp,
"A or B should be a non-unit dim in acc.");
249 unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
250 : nonUnitDimRhs.front();
251 if (nonUnitDim != 4 && nonUnitDim != 8 &&
252 !(nonUnitDimAcc.front() == nonUnitDim))
254 contractOp,
"BF16 packed load operation expects non-unit (LHR or "
255 "RHS) dim and acc dim of size 4/8.");
260 contractOp,
"The LHS or RHS is in an invalid format. Either it has "
262 "a non-identity permutation map, a non-zero VNNI offset, "
264 "source, or a non-unit VNNI stride");
268 auto loc = contractOp.getLoc();
276 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
279 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
281 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
284 rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
288 nonUnitDimShape,
true);
291 loc, rewriter, nonUnitSrc, nonUnitDimShape,
false);
293 auto castAcc = vector::ShapeCastOp::create(
295 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
296 contractOp.getAcc());
298 VectorType::get(nonUnitDimAcc.front(), rewriter.
getF32Type());
301 auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
302 rewriter, loc, dstType, unitDimSubview[0]);
303 auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
304 rewriter, loc, dstType, nonUnitDimSubview[0]);
306 vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
307 loadOddIdxElementF32, castAcc);
310 auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
311 rewriter, loc, dstType, unitDimSubview[1]);
312 auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
313 rewriter, loc, dstType, nonUnitDimSubview[0]);
315 vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
316 loadEvenIdxElementF32, oddIdxFMA);
318 auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);