//===- VectorContractBF16ToFMA.cpp ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

// Verifies that the LHS and RHS operands of a vector.contract are vector.load
// or vector.transfer_read operations on a memref source buffer, and checks
// their bounds, dimensions, offsets, and strides.
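// For example (illustrative only), an accepted producer looks like:
//   %lhs = vector.transfer_read %A[%i, %j, %c0], %pad
//            {in_bounds = [true, true, true]}
//            : memref<4x8x2xbf16>, vector<1x8x2xbf16>
// i.e. fully in-bounds, a minor-identity permutation map, a memref source
// with contiguous innermost dims, and a zero innermost (VNNI) index.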
static bool validateVectorContractOperands(Value prodOp) {
  Operation *defOp = prodOp.getDefiningOp();
  if (!defOp)
    return false;

  if (auto readOp = prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
    if (readOp.hasOutOfBoundsDim())
      return false;

    if (!readOp.getPermutationMap().isMinorIdentity())
      return false;
  }

  Value srcBuff;
  SmallVector<OpFoldResult> indexVals;
  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
      [&](auto readOp) {
        srcBuff = readOp.getOperand(0);
        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                              readOp.getIndices().end());
      });

  if (!srcBuff)
    return false;

  // Return false if the source is not a memref type.
  Type srcType = srcBuff.getType();
  if (!llvm::isa<MemRefType>(srcType))
    return false;

  // Return false if the two innermost dims of the memref are not contiguous.
  // The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
  // an eight-element tuple of bf16 values to be contiguous.
  if (!llvm::cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(2))
    return false;

  // Return false if the VNNI offset of the load or transfer_read is not zero.
  if (getConstantIntValue(indexVals.back()) != 0)
    return false;

  return true;
}

// Retrieves the memref source of the load or transfer_read producer and
// creates subviews from which the BF16 packed operations broadcast or load
// BF16 elements as packed F32 elements.
//
// Example (1), unit dim:
// ```
// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
// ```
// to
// ```
// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
// ```
//
// Example (2), non-unit dim:
// ```
// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
// ```
// to
// ```
// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
// ```
static SmallVector<memref::SubViewOp>
getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp,
                          ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim) {

  Operation *defOp = prodOp.getDefiningOp();

  Value srcBuff;
  SmallVector<OpFoldResult> indexVals;
  llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
      [&](auto readOp) {
        srcBuff = readOp.getOperand(0);
        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                              readOp.getIndices().end());
      });

  int64_t mnDimSize = 1;
  unsigned mnDimIdx = 0;

  if (!isUnitDim) {
    for (auto it : llvm::enumerate(nonUnitDimShape)) {
      if (it.value() != 1) {
        mnDimSize = it.value();
        mnDimIdx = it.index();
        break;
      }
    }
  }
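  // For example, with nonUnitDimShape = {1, 8, 2} (the vector<1x8x2xbf16>
  // case above), this loop picks mnDimSize = 8 at mnDimIdx = 1.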

  int vnniDimSize = isUnitDim ? 1 : 2;

  auto nonVNNIDimSize = indexVals.size() - 1;
  // Create the subview sizes and strides.
  auto one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(indexVals.size(), one);
  SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);

  sizes.push_back(rewriter.getIndexAttr(vnniDimSize));

  // Update the unit/non-unit dim size, whether it belongs to A (LHS) or
  // B (RHS).
  sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);

  // For the unit dim, the odd element is broadcast first, so the VNNI index
  // is set to 1.
  if (isUnitDim)
    indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);

  SmallVector<memref::SubViewOp> subviews;
  auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
                                           sizes, strides);
  subviews.push_back(subview);

  // For unit dims, two subviews should be created, one for the odd and one
  // for the even element in the VNNI tuple (2xbf16), because the
  // x86vector.avx.bcst_to_f32.packed op loads and broadcasts the first BF16
  // element into packed F32. It cannot distinguish between even and odd BF16
  // elements within a packed pair.
  //
  // Example:
  // memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
  // memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
  if (mnDimSize == 1) {
    indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
    sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);

    auto unitDimEvenIdxSubview = memref::SubViewOp::create(
        rewriter, loc, srcBuff, indexVals, sizes, strides);
    subviews.push_back(unitDimEvenIdxSubview);
  }

  return subviews;
}

// Implements an outer-product contraction as a sequence of BF16 packed
// even/odd loads, broadcasts, and FMA operations.
//
// For example:
// ```
// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
// return vector.contract %1, %2, %arg1
// ```
// to
// ```
// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
// %3 = vector.fma %1, %2, %arg1
// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
// return vector.fma %4, %5, %3
// ```
struct VectorContractBF16ToFMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind.");

    // TODO: Move this validation to a common utility folder. Planned to be
    // done (as a code refactoring) once all architecture-specific nanokernel
    // passes are merged into the repo.
    VectorType lhsTy = contractOp.getLhsType();
    if (!lhsTy.getElementType().isBF16())
      return rewriter.notifyMatchFailure(contractOp,
                                         "Only BF16 lowering is supported.");

    if (!isInVnniLayout(contractOp.getOperation(),
                        contractOp.getIndexingMapsArray(),
                        /*blockingFactor=*/2))
      return rewriter.notifyMatchFailure(contractOp,
                                         "Input matrices not in VNNI format.");

    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
    if (!accTy)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Wrong accumulator type.");

    if (!accTy.getElementType().isF32())
      return rewriter.notifyMatchFailure(
          contractOp, "Only F32 accumulation supported for BF16 type.");

    ArrayRef<int64_t> lhsShape = lhsTy.getShape();
    llvm::SmallVector<int64_t> nonUnitDimLhs;
    llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
                  [](int64_t dim) { return dim != 1; });

    VectorType rhsTy = contractOp.getRhsType();
    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
    llvm::SmallVector<int64_t> nonUnitDimRhs;
    llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
                  [](int64_t dim) { return dim != 1; });

    if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects unit dimensions for either "
                                         "LHS or RHS shape other than VNNI.");

    if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
      return rewriter.notifyMatchFailure(
          contractOp,
          "Expects one non-unit A/B dimension for either LHS or RHS shape.");

    ArrayRef<int64_t> accShape = accTy.getShape();
    llvm::SmallVector<int64_t> nonUnitDimAcc;
    llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
                  [](int64_t dim) { return dim != 1; });
    if (nonUnitDimAcc.size() != 1)
      return rewriter.notifyMatchFailure(
          contractOp, "Expects exactly one non-unit (A or B) dim in acc.");

    // The non-unit dimension must match the BF16 packed vector length, and
    // the accumulator's non-unit dimension must match it as well.
    unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
                                                        : nonUnitDimRhs.front();
    if ((nonUnitDim != 4 && nonUnitDim != 8) ||
        nonUnitDimAcc.front() != nonUnitDim)
      return rewriter.notifyMatchFailure(
          contractOp, "BF16 packed load operation expects non-unit (LHS or "
                      "RHS) dim and acc dim of size 4/8.");

    if (!validateVectorContractOperands(contractOp.getLhs()) ||
        !validateVectorContractOperands(contractOp.getRhs())) {
      return rewriter.notifyMatchFailure(
          contractOp, "The LHS or RHS is in an invalid format: it has an "
                      "out-of-bounds dim, a non-identity permutation map, a "
                      "non-zero VNNI offset, a non-memref source, or "
                      "non-contiguous innermost dims.");
    }

    // Lower vector.contract to FMAs with the help of BF16 packed ops.
    auto loc = contractOp.getLoc();

    // Create the unit-dimension LHS or RHS subview and the corresponding
    // non-unit-dimension subview on the other side. For example, if LHS has
    // type vector<1x1x2xbf16> and RHS has type vector<1x8x2xbf16>, we create
    // two subviews for the LHS and one subview for the RHS. In the opposite
    // case (non-unit dimension on the LHS), we do the reverse.
    bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
    // Select which operand is "unit" and which is "non-unit".
    Value unitSrc =
        rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
    Value nonUnitSrc =
        rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();

    ArrayRef<int64_t> nonUnitDimShape =
        rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;

    // Build subviews.
    auto unitDimSubview = getSubviewFromVectorInput(
        loc, rewriter, unitSrc, nonUnitDimShape, /*isUnitDim=*/true);

    auto nonUnitDimSubview = getSubviewFromVectorInput(
        loc, rewriter, nonUnitSrc, nonUnitDimShape, /*isUnitDim=*/false);

    auto castAcc = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
        contractOp.getAcc());
    VectorType dstType =
        VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
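    // The shape cast above flattens the accumulator (e.g. vector<1x8xf32> to
    // vector<8xf32>) so it can feed vector.fma directly; dstType is the
    // matching packed-f32 result type of the x86vector ops below.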

    // Load, broadcast, and do FMA for odd-indexed BF16 elements.
    auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
        rewriter, loc, dstType, unitDimSubview[0]);
    auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
        rewriter, loc, dstType, nonUnitDimSubview[0]);
    auto oddIdxFMA =
        vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
                              loadOddIdxElementF32, castAcc);
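    // The odd-index FMA result becomes the accumulator of the even-index FMA,
    // so the final result is acc + a_odd * b_odd + a_even * b_even.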

    // Load, broadcast, and do FMA for even-indexed BF16 elements.
    auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
        rewriter, loc, dstType, unitDimSubview[1]);
    auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
        rewriter, loc, dstType, nonUnitDimSubview[0]);
    vector::FMAOp fma =
        vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
                              loadEvenIdxElementF32, oddIdxFMA);

    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
    rewriter.replaceOp(contractOp, castFma);
    return success();
  }
};

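// Registration hook for the pattern above; a minimal sketch, assuming the
// hook is declared in the mlir::x86vector namespace.
void x86vector::populateVectorContractBF16ToFMAPatterns(
    RewritePatternSet &patterns) {
  patterns.add<VectorContractBF16ToFMA>(patterns.getContext());
}

// Example usage (sketch), e.g. inside a pass with a `funcOp` and an
// `MLIRContext *ctx` at hand:
//   RewritePatternSet patterns(ctx);
//   x86vector::populateVectorContractBF16ToFMAPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));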