X86VectorUtils.cpp
//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include <cassert>

namespace mlir {
namespace x86vector {

// Infers the iterator types of a contraction from its output indexing map:
// dimensions that appear in the output map are parallel, all remaining
// dimensions are reduction.
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<mlir::utils::IteratorType> iterators(
      map.getNumDims(), mlir::utils::IteratorType::reduction);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
  return iterators;
}

// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
                    std::optional<unsigned> blockingFactor) {
  // Narrow down the op type - VNNI only applies to contractions.
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return false;

  auto matA = op->getOperand(0);
  auto matB = op->getOperand(1);
  auto typeA = dyn_cast<ShapedType>(matA.getType());
  auto typeB = dyn_cast<ShapedType>(matB.getType());
  if (!typeA || !typeB)
    return false;
  unsigned rankA = typeA.getRank();
  unsigned rankB = typeB.getRank();
  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
  if (rankA < 3 || rankB < 3)
    return false;

  // At least two reduction dimensions are expected:
  // one for the VNNI factor and one for the K dimension.
  if (dims->k.size() < 2)
    return false;

  // Validate affine maps - VNNI computation should be defined by the two
  // innermost reduction iterators.
  // The input matrix dimensions layout must match the following:
  // - matrix A - [...][K/vnniFactor][vnniFactor]
  // - matrix B - [...][K/vnniFactor][N][vnniFactor]
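  // As an illustration (assumed shapes and maps, not taken from this file),
  // a bf16 contraction with vnniFactor = 2 could look like:
  //   A: vector<1x4x2xbf16>  with map (d0, d1, d2, d3) -> (d0, d2, d3)
  //   B: vector<4x16x2xbf16> with map (d0, d1, d2, d3) -> (d2, d1, d3)
  //   C: vector<1x16xf32>    with map (d0, d1, d2, d3) -> (d0, d1)
  // where d0/d1 are parallel (M/N) and d2/d3 are reduction (K and VNNI).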
  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
  if (failed(maybeIters))
    return false;
  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];

  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
                           mlir::utils::IteratorType::parallel)
    return false;

  // VNNI factor must be:
  // - the innermost inputs' dimension
  // - statically known
  // - a multiple of 2 and, when requested, equal to the blocking factor
  auto vnniDimSize = typeB.getShape().back();
  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
      vnniDimSize % 2 != 0)
    return false;
  if (typeA.getShape().back() != vnniDimSize)
    return false;
  if (blockingFactor && vnniDimSize != *blockingFactor)
    return false;

  // The split reduction dimension size should also match.
  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
    return false;

  return true;
}

// Pair of shuffle masks used to interleave two flat accumulator vectors.
struct ShuffleMasks {
  ArrayRef<int64_t> maskLo;
  ArrayRef<int64_t> maskHi;
};

inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
  // We only support these two layouts for now.
  assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
         "Unsupported nonUnitDimAcc value");
  // Do interleaving between two <8xf32> targeting AVX2.
  static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
  static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};

  // Shuffle two <16xf32> as below targeting AVX512.
  static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
                                         4, 5, 6, 7, 20, 21, 22, 23};
  static constexpr int64_t maskHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
                                         12, 13, 14, 15, 28, 29, 30, 31};

  if (nonUnitDimAcc == 16)
    return {maskLo16, maskHi16};

  return {maskLo8, maskHi8};
}
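
// As a worked illustration (not from the original source): with
// nonUnitDimAcc == 8, shuffling A = [a0..a7] and B = [b0..b7] with
// maskLo8/maskHi8 yields
//   lo = [a0, b0, a1, b1, a2, b2, a3, b3]
//   hi = [a4, b4, a5, b5, a6, b6, a7, b7]
// i.e. a full element-wise interleave, while the AVX512 masks interleave the
// two <16xf32> vectors in blocks of four elements.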

// This function walks backward from a value to locate its originating
// vector read-like operation (`vector.transfer_read` or `vector.load`).
// It follows values across `scf.for` iter-arguments back to the loop's
// init values and stops at any other defining operation (including
// layout-transforming ops such as `shape_cast` or `shuffle`). The
// traversal returns the read-like defining operation or `nullptr` if
// no valid source is found.
Operation *traceToVectorReadLikeParentOperation(Value v) {
  while (true) {
    // Case 1: Value defined by an operation
    if (Operation *defOp = v.getDefiningOp()) {
      if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
        return defOp;

      return nullptr;
    }

    // Case 2: BlockArgument (scf.for iter_arg)
    if (auto barg = dyn_cast<BlockArgument>(v)) {
      auto *parentOp = barg.getOwner()->getParentOp();

      if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
        unsigned argNum = barg.getArgNumber();

        // arg0 = induction variable (not an iter_arg)
        if (argNum == 0)
          return nullptr;

        unsigned iterIdx = argNum - 1;
        v = forOp.getInitArgs()[iterIdx];
        continue;
      }

      return nullptr;
    }

    return nullptr;
  }
}
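
// For example (hypothetical IR), given
//   %init = vector.transfer_read %buf[...], %pad : memref<...>, vector<8xf32>
//   scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> ... { ... }
// calling the helper on the block argument %acc follows the iter_arg back to
// %init and returns the vector.transfer_read.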

// This function recursively traces a value through its uses to find
// a downstream vector write-like operation (`vector.transfer_write`
// or `vector.store`). It transparently follows values across `scf.for`
// and `scf.yield` boundaries while stopping if layout-altering ops such
// as `shape_cast` or `shuffle` are encountered. The traversal returns
// the matching write-like user, or `nullptr` if none is found or the
// value has multiple users.
Operation *traceToVectorWriteLikeUserOperation(Value v) {

  if (v.getNumUses() > 1)
    return nullptr;

  for (OpOperand &use : v.getUses()) {
    Operation *user = use.getOwner();

    // --- TERMINAL OPS ---
    if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
      return user;

    if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
      return nullptr;

    // --- SCF YIELD ---
    if (auto yield = dyn_cast<scf::YieldOp>(user)) {
      Operation *parent = yield->getParentOp();
      unsigned idx = use.getOperandNumber();
      if (auto *res =
              traceToVectorWriteLikeUserOperation(parent->getResult(idx)))
        return res;
      continue;
    }

    // --- SCF FOR ---
    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
      // Skip the loop control operands (lower bound, upper bound, step) so
      // that the iter_arg operand maps to its corresponding loop result.
      unsigned idx = use.getOperandNumber() - forOp.getNumControlOperands();
      if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
        return res;
      continue;
    }

    // --- GENERIC CASE ---
    for (Value res : user->getResults()) {
      if (auto *found = traceToVectorWriteLikeUserOperation(res))
        return found;
    }
  }

  return nullptr;
}
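
// For example (hypothetical IR), when a contraction result is yielded by an
// scf.for whose result is then stored:
//   %res = scf.for ... iter_args(%acc = %init) -> (vector<8xf32>) {
//     %v = vector.contract ... %acc ...
//     scf.yield %v : vector<8xf32>
//   }
//   vector.transfer_write %res, %buf[...]
// tracing %v reaches the scf.yield, hops to the matching loop result %res,
// and returns the vector.transfer_write user.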

// This function packs the accumulators of two flat BF16 vector.contract
// operations into VNNI layout; the packed values then replace the original
// accumulators in their respective contraction ops, enabling post-read
// layout or packing transformations.
// TODO: replace all uses with the packed value along with the contraction
// and for op.
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
                                     Operation *opB,
                                     vector::ContractionOp contractA,
                                     vector::ContractionOp contractB,
                                     int64_t nonUnitDimAcc, VectorType accTy) {

  if (!isa<vector::TransferReadOp, vector::LoadOp>(opA) ||
      !isa<vector::TransferReadOp, vector::LoadOp>(opB)) {
    return failure();
  }

  Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;

  rewriter.setInsertionPointAfter(insertAfter);
  Location loc = insertAfter->getLoc();

  auto elemTy = accTy.getElementType();
  auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);

  auto castA =
      vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->getResult(0));
  auto castB =
      vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));

  auto masks = getShuffleMasks(nonUnitDimAcc);

  auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                             castB, masks.maskLo);
  auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                             castB, masks.maskHi);

  auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
  auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);

  rewriter.replaceUsesWithIf(
      opA->getResult(0), newAccA.getResult(), [&](OpOperand &use) {
        return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
      });

  rewriter.replaceUsesWithIf(
      opB->getResult(0), newAccB.getResult(), [&](OpOperand &use) {
        return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
      });

  return success();
}
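
// Rough sketch of the rewrite (hypothetical types, nonUnitDimAcc = 8): the two
// read results are flattened with vector.shape_cast, interleaved as
//   %lo = vector.shuffle %a, %b [0, 8, 1, 9, 2, 10, 3, 11]
//   %hi = vector.shuffle %a, %b [4, 12, 5, 13, 6, 14, 7, 15]
// cast back to accTy, and substituted for the original reads in the
// contraction (or enclosing scf.for) operands.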

// This function shuffles the vectors produced by a pair of vector.contract
// operations back into a flat layout before they are stored.
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
                                       Operation *opA, Operation *opB,
                                       int64_t nonUnitDimAcc,
                                       VectorType accTy) {
  // Helper to extract the vector operand from write-like ops.
  auto getWrittenVector = [](Operation *op) -> Value {
    if (auto write = dyn_cast<vector::TransferWriteOp>(op))
      return write.getVector();
    if (auto store = dyn_cast<vector::StoreOp>(op))
      return store.getValueToStore();
    return nullptr;
  };

  Value vecA = getWrittenVector(opA);
  Value vecB = getWrittenVector(opB);

  if (!vecA || !vecB)
    return failure();

  // Decide insertion point and location.
  Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;

  rewriter.setInsertionPoint(insertBefore);
  Location loc = insertBefore->getLoc();

  auto elemTy = accTy.getElementType();
  auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);

  // Flatten vectors.
  auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
  auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);

  // TODO: derive shuffle masks instead of hard-coding.
  auto masks = getShuffleMasks(nonUnitDimAcc);

  auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                              castB, masks.maskLo);
  auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                              castB, masks.maskHi);

  // Cast back to accumulator type.
  auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
  auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);

  // Update write operands in place.
  opA->setOperand(0, newVecA.getResult());
  opB->setOperand(0, newVecB.getResult());

  return success();
}
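
// Note: unlike shuffleAfterReadLikeOp, which redirects uses of the read
// results, this rewrite updates operand 0 of the write-like ops in place,
// so the shuffled values are only observed by the stores themselves.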

// Returns true if the two vector.contract operations match all of the
// following conditions:
//   (1) the unit-dim operand (LHS or RHS) is the same for both contracts,
//   (2) the non-unit-dim operands are read from the same source buffer,
//   (3) the non-unit-dim index offset between the two vector.contracts
//       equals nonUnitDimValue (e.g. 8 or 16).
bool validatePairVectorContract(vector::ContractionOp contractOp,
                                vector::ContractionOp pairContOp,
                                bool rhsHasMultipleNonUnitDims,
                                int64_t nonUnitDimValue) {
  if (rhsHasMultipleNonUnitDims &&
      !(contractOp.getLhs() == pairContOp.getLhs()))
    return false;

  if (!rhsHasMultipleNonUnitDims &&
      !(contractOp.getRhs() == pairContOp.getRhs()))
    return false;

  auto nonUnitOperand =
      rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
  auto nonUnitOperandPairContOp =
      rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();

  Value srcBuff;
  SmallVector<OpFoldResult> indexVals;
  llvm::TypeSwitch<Operation *>(nonUnitOperand.getDefiningOp())
      .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
        srcBuff = readOp.getOperand(0);
        indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
                                              readOp.getIndices().end());
      })
      .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
        srcBuff = op.getSource();
        indexVals.clear();
      });

  Value srcBuffPairContOp;
  SmallVector<OpFoldResult> indexValsPairContOp;
  llvm::TypeSwitch<Operation *>(nonUnitOperandPairContOp.getDefiningOp())
      .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
        srcBuffPairContOp = readOp.getOperand(0);
        indexValsPairContOp = SmallVector<OpFoldResult>(
            readOp.getIndices().begin(), readOp.getIndices().end());
      })
      .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
        srcBuffPairContOp = op.getSource();
        indexValsPairContOp.clear();
      });

  if (!srcBuff || !srcBuffPairContOp)
    return false;

  auto shuffleLw = srcBuff.getDefiningOp<vector::ShuffleOp>();
  auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();

  if (shuffleLw && shuffleHw)
    return shuffleLw.getV1() == shuffleHw.getV1() &&
           shuffleLw.getV2() == shuffleHw.getV2();

  if (srcBuff != srcBuffPairContOp)
    return false;

  for (size_t i = 0; i < indexVals.size(); i++) {
    auto v0 = getConstantIntValue(indexVals[i]);
    auto v1 = getConstantIntValue(indexValsPairContOp[i]);

    if (!v0 || !v1)
      return false;

    if (*v1 == *v0)
      continue;

    if ((*v1 - *v0) != nonUnitDimValue)
      return false;
  }

  return true;
}
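
// For instance (hypothetical values): two contracts whose non-unit-dim
// operands are vector.loads from the same memref at indices [%i, %c0] and
// [%i, %c16] form a valid pair when nonUnitDimValue == 16; any other index
// delta rejects the pair.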

} // namespace x86vector
} // namespace mlir