MLIR 23.0.0git
X86Utils.cpp
Go to the documentation of this file.
1//===- X86Utils.cpp - MLIR Utilities for X86Ops -------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include "mlir/IR/Types.h"
18
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Casting.h"
21
22#include "llvm/ADT/ArrayRef.h"
23#include <cassert>
24
25namespace mlir {
26namespace x86 {
27
28static FailureOr<SmallVector<mlir::utils::IteratorType>>
30 if (!map.isProjectedPermutation())
31 return failure();
33 map.getNumDims(), mlir::utils::IteratorType::reduction);
34 for (auto expr : map.getResults())
35 if (auto dim = dyn_cast<AffineDimExpr>(expr))
36 iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
37 return iterators;
38}
39
40// Returns true if the operation is in VNNI layout.
41// Optionally, the check can be constrained to a specific VNNI blocking factor.
43 std::optional<unsigned> blockingFactor) {
44 // Narrow down type operations - VNNI only applies to contractions.
45 FailureOr<linalg::ContractionDimensions> dims =
46 linalg::inferContractionDims(indexingMaps);
47 if (failed(dims))
48 return false;
49
50 auto matA = op->getOperand(0);
51 auto matB = op->getOperand(1);
52 auto typeA = dyn_cast<ShapedType>(matA.getType());
53 auto typeB = dyn_cast<ShapedType>(matB.getType());
54 unsigned rankA = typeA.getRank();
55 unsigned rankB = typeB.getRank();
56 // VNNI format requires at least 1 parallel and 2 reduction dimensions.
57 if (rankA < 3 || rankB < 3)
58 return false;
59
60 // At least two reduction dimensions are expected:
61 // one for the VNNI factor and one for the K dimension
62 if (dims->k.size() < 2)
63 return false;
64
65 // Validate affine maps - VNNI computation should be defined by the two
66 // innermost reduction iterators.
67 // The input matrix dimensions layout must match the following:
68 // - matrix A - [...][K/vnniFactor][vnniFactor]
69 // - matrix B - [...][K/vnniFactor][N][vnniFactor]
70 auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
71 if (failed(maybeIters))
72 return false;
73 SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
74 AffineMap mapA = indexingMaps[0];
75 AffineMap mapB = indexingMaps[1];
76
77 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
78 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
79 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
80 iteratorTypes[vnniDimA.getPosition()] !=
81 mlir::utils::IteratorType::reduction)
82 return false;
83 auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
84 auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
85 if (!redDimA || !redDimB || redDimA != redDimB ||
86 iteratorTypes[redDimA.getPosition()] !=
87 mlir::utils::IteratorType::reduction)
88 return false;
89 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
90 if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
91 mlir::utils::IteratorType::parallel)
92 return false;
93
94 // VNNI factor must be:
95 // - the innermost inputs' dimension
96 // - statically known
97 // - multiple of 2 or equal to the specified factor
98 auto vnniDimSize = typeB.getShape().back();
99 if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
100 vnniDimSize % 2 != 0)
101 return false;
102 if (typeA.getShape().back() != vnniDimSize)
103 return false;
104 if (blockingFactor && vnniDimSize != *blockingFactor)
105 return false;
106
107 // The split reduction dimension size should also match.
108 if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
109 return false;
110
111 return true;
112}
113
118
119inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2) {
120 // We only support these two layouts for now.
121 assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
122 "Unsupported nonUnitDimAcc value");
123
124 // Do interleaving between two <8xf32> targeting AVX2.
125 static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
126 static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
127
128 // Do interleaving between two <8xi32> targeting AVX2.
129 static constexpr int64_t maskLo8_avx2_int8[] = {0, 1, 2, 3, 8, 9, 10, 11};
130 static constexpr int64_t maskHi8_avx2_int8[] = {4, 5, 6, 7, 12, 13, 14, 15};
131
132 // Shuffle two <16xf32/i32> as below targeting AVX512.
133 static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
134 4, 5, 6, 7, 20, 21, 22, 23};
135 static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
136 12, 13, 14, 15, 28, 29, 30, 31};
137
138 if (nonUnitDimAcc == 16)
139 return {maskLo16, maskHi16};
140
141 if (isInt8Avx2)
142 return {maskLo8_avx2_int8, maskHi8_avx2_int8};
143
144 return {maskLo8, maskHi8};
145}
146
147// This function walks backward from a value to locate its originating
148// vector read-like operation (`vector.transfer_read` or `vector.load`).
149// It follows simple forwarding through unary ops and across `scf.for`
150// loop iter-arguments, while stopping if layout-transforming ops such
151// as `shape_cast` or `shuffle` are encountered. The traversal returns
152// the read-like defining operation or `nullptr` if no valid source
153// is found.
155 while (true) {
156 // Case 1: Value defined by an operation
157 if (Operation *defOp = v.getDefiningOp()) {
158 if (isa<vector::TransferReadOp, vector::LoadOp>(defOp))
159 return defOp;
160
161 return nullptr;
162 }
163
164 // Case 2: BlockArgument (scf.for iter_arg)
165 if (auto barg = dyn_cast<BlockArgument>(v)) {
166 auto *parentOp = barg.getOwner()->getParentOp();
167
168 if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
169 unsigned argNum = barg.getArgNumber();
170
171 // arg0 = induction variable (not an iter_arg)
172 if (argNum == 0)
173 return nullptr;
174
175 unsigned iterIdx = argNum - 1;
176 v = forOp.getInitArgs()[iterIdx];
177 continue;
178 }
179
180 return nullptr;
181 }
182
183 return nullptr;
184 }
185}
186
187// This function recursively traces a value through its uses to find
188// a downstream vector write-like operation (`vector.transfer_write`
189// or `vector.store`). It transparently follows values across `scf.for`
190// and `scf.yield` boundaries while stopping if layout-altering ops such
191// as `shape_cast` or `shuffle` are encountered. The traversal returns
192// the matching write-like user. Returns `nullptr` if none is found or
193// the value has multiple users.
195
196 if (v.getNumUses() > 1)
197 return nullptr;
198
199 for (OpOperand &use : v.getUses()) {
200 Operation *user = use.getOwner();
201
202 // --- TERMINAL OPS ---
203 if (isa<vector::TransferWriteOp>(user) || isa<vector::StoreOp>(user))
204 return user;
205
206 if (isa<vector::ShapeCastOp, vector::ShuffleOp>(user))
207 return nullptr;
208
209 // --- SCF YIELD ---
210 if (auto yield = dyn_cast<scf::YieldOp>(user)) {
211 Operation *parent = yield->getParentOp();
212 unsigned idx = use.getOperandNumber();
213 if (auto *res =
215 return res;
216 continue;
217 }
218
219 // --- SCF FOR ---
220 if (auto forOp = dyn_cast<scf::ForOp>(user)) {
221 unsigned idx = use.getOperandNumber();
222 if (auto *res = traceToVectorWriteLikeUserOperation(forOp.getResult(idx)))
223 return res;
224 continue;
225 }
226
227 // --- GENERIC CASE ---
228 for (Value res : user->getResults()) {
229 if (auto *found = traceToVectorWriteLikeUserOperation(res))
230 return found;
231 }
232 }
233
234 return nullptr;
235}
236
237// This function packs the accumulator of two flat BF16 vector.contract
238// operations into VNNI packed and are then replaced in their respective
239// contraction ops, enabling post-read layout or packing transformations.
240// TODO: replace all use with the packed value along with contration
241// and for op.
243 Operation *opB,
244 vector::ContractionOp contractA,
245 vector::ContractionOp contractB,
246 int64_t nonUnitDimAcc, VectorType accTy) {
247
248 if (!isa<vector::TransferReadOp, vector::LoadOp>(opA) ||
249 !isa<vector::TransferReadOp, vector::LoadOp>(opB)) {
250 return failure();
251 }
252
253 Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
254
255 rewriter.setInsertionPointAfter(insertAfter);
256 Location loc = insertAfter->getLoc();
257
258 auto elemTy = accTy.getElementType();
259 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
260
261 auto castA =
262 vector::ShapeCastOp::create(rewriter, loc, flatTy, opA->getResult(0));
263 auto castB =
264 vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));
265
266 auto masks = getShuffleMasks(
267 nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
268
269 auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
270 castB, masks.maskLo);
271 auto shuffleHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
272 castB, masks.maskHi);
273
274 auto newAccA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleLo);
275 auto newAccB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffleHi);
276
277 rewriter.replaceUsesWithIf(
278 opA->getResult(0), newAccA.getResult(), [&](OpOperand &use) {
279 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
280 });
281
282 rewriter.replaceUsesWithIf(
283 opB->getResult(0), newAccB.getResult(), [&](OpOperand &use) {
284 return isa<vector::ContractionOp, scf::ForOp>(use.getOwner());
285 });
286
287 return success();
288}
289
290// This function shuffles the vectors written by vector.contract operation
291// as a flat layout structure before they are stored.
293 Operation *opA, Operation *opB,
294 int64_t nonUnitDimAcc,
295 VectorType accTy) {
296 // Helper to extract vector operand from write-like ops
297 auto getWrittenVector = [](Operation *op) -> Value {
298 if (auto write = dyn_cast<vector::TransferWriteOp>(op))
299 return write.getVector();
300 if (auto store = dyn_cast<vector::StoreOp>(op))
301 return store.getValueToStore();
302 return nullptr;
303 };
304
305 Value vecA = getWrittenVector(opA);
306 Value vecB = getWrittenVector(opB);
307
308 if (!vecA || !vecB)
309 return failure();
310
311 // Decide insertion point and location
312 Operation *insertBefore = opA->isBeforeInBlock(opB) ? opA : opB;
313
314 rewriter.setInsertionPoint(insertBefore);
315 Location loc = insertBefore->getLoc();
316
317 auto elemTy = accTy.getElementType();
318 auto flatTy = VectorType::get(nonUnitDimAcc, elemTy);
319
320 // Flatten vectors
321 auto castA = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecA);
322 auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
323
324 // TODO: derive shuffle masks instead of hard-coding
325 auto masks = getShuffleMasks(
326 nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
327
328 auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
329 castB, masks.maskLo);
330 auto shuffledHi = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
331 castB, masks.maskHi);
332
333 // Cast back to accumulator type
334 auto newVecA = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledLo);
335 auto newVecB = vector::ShapeCastOp::create(rewriter, loc, accTy, shuffledHi);
336
337 // Update write operands in place via the rewriter to notify it of changes.
338 rewriter.modifyOpInPlace(opA,
339 [&]() { opA->setOperand(0, newVecA.getResult()); });
340 rewriter.modifyOpInPlace(opB,
341 [&]() { opB->setOperand(0, newVecB.getResult()); });
342
343 return success();
344}
345
346// Return true if vector.contract operations matches on below conditions:
347// (1) - the unitDim operand Lhs or Rhs should be same,
348// (2) - the defining source memref should be same for nonUnitDim
349// operation,
350// (3) - the nonUnit dim offset difference between the
351// vector.contracts should be 8 or 16.
352bool validatePairVectorContract(vector::ContractionOp contractOp,
353 vector::ContractionOp pairContOp,
354 bool rhsHasMultipleNonUnitDims,
355 int64_t nonUnitDimValue) {
356 if (contractOp == pairContOp)
357 return false;
358
359 if (rhsHasMultipleNonUnitDims &&
360 !(contractOp.getLhs() == pairContOp.getLhs()))
361 return false;
362
363 if (!rhsHasMultipleNonUnitDims &&
364 !(contractOp.getRhs() == pairContOp.getRhs()))
365 return false;
366
367 auto nonUnitOperand =
368 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
369 auto nonUnitOperandPairContOp =
370 rhsHasMultipleNonUnitDims ? pairContOp.getRhs() : pairContOp.getLhs();
371
372 Value srcBuff;
374 llvm::TypeSwitch<Operation *>(nonUnitOperand.getDefiningOp())
375 .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
376 srcBuff = readOp.getOperand(0);
377 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
378 readOp.getIndices().end());
379 })
380 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
381 srcBuff = op.getSource();
382 indexVals.clear();
383 });
384
385 Value srcBuffPairContOp;
386 SmallVector<OpFoldResult> indexValsPairContOp;
387 llvm::TypeSwitch<Operation *>(nonUnitOperandPairContOp.getDefiningOp())
388 .Case<vector::TransferReadOp, vector::LoadOp>([&](auto readOp) {
389 srcBuffPairContOp = readOp.getOperand(0);
390 indexValsPairContOp = SmallVector<OpFoldResult>(
391 readOp.getIndices().begin(), readOp.getIndices().end());
392 })
393 .Case<vector::ShapeCastOp>([&](vector::ShapeCastOp op) {
394 srcBuffPairContOp = op.getSource();
395 indexVals.clear();
396 });
397
398 if (!srcBuff || !srcBuffPairContOp)
399 return false;
400
401 auto shuffleLw = srcBuff.getDefiningOp<vector::ShuffleOp>();
402 auto shuffleHw = srcBuffPairContOp.getDefiningOp<vector::ShuffleOp>();
403
404 if (shuffleLw && shuffleHw)
405 return shuffleLw.getV1() == shuffleHw.getV1() &&
406 shuffleLw.getV2() == shuffleHw.getV2();
407
408 if (srcBuff != srcBuffPairContOp)
409 return false;
410
411 bool oneConstantOffset = false;
412 for (size_t i = 0; i < indexVals.size(); i++) {
413
414 if (indexVals[i] == indexValsPairContOp[i])
415 continue;
416
417 auto v0 = getConstantIntValue(indexVals[i]);
418 auto v1 = getConstantIntValue(indexValsPairContOp[i]);
419
420 if (!v0 || !v1)
421 return false;
422
423 if ((*v1 - *v0) != nonUnitDimValue)
424 return false;
425
426 oneConstantOffset = true;
427 }
428
429 return oneConstantOffset;
430}
431
432} // namespace x86
433} // namespace mlir
return success()
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
AffineExpr getResult(unsigned idx) const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:292
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
static FailureOr< SmallVector< mlir::utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Definition X86Utils.cpp:29
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2)
Definition X86Utils.cpp:119
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:242
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::ArrayRef< int64_t > maskHi
Definition X86Utils.cpp:116
llvm::ArrayRef< int64_t > maskLo
Definition X86Utils.cpp:115