//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/DebugLog.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOAMX
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "vector-to-amx"

namespace {

/// Return true if vector shape is compatible with AMX tiles.
/// The validation accounts for VNNI packing.
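/// For illustration, a sketch with assumed bf16 operands: a VNNI-packed
/// vector<16x8x2xbf16> has vnniFactor = 32 / 16 = 2, giving an effective
/// unpacked shape of 16 rows x (8 * 2) = 16 columns, i.e., 32 bytes per
/// row, which fits within the tile limits checked below.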
static bool verifyAmxShape(VectorType vec) {
  // Check overall shape:
  // - 2D for plain layout input or output
  // - 3D for VNNI packed input
  if (vec.getRank() != 2 && vec.getRank() != 3)
    return false;

  ArrayRef<int64_t> shape = vec.getShape();
  int64_t rows = shape[0];
  int64_t cols = shape[1];
  unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();

  // 3D shape indicates VNNI packed layout.
  if (vec.getRank() == 3) {
    int64_t vnniFactor = 32 / elemBitWidth;
    if (shape.back() != vnniFactor) {
      LDBG() << "invalid VNNI packing factor";
      return false;
    }
    cols *= vnniFactor;
  }

  // AMX tile supports up to 16 rows of 64 bytes each.
  constexpr unsigned maxRows = 16;
  constexpr unsigned maxBitsPerRow = 64 * 8;
  return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
}

/// Check if contraction operands are in AMX-compatible packed VNNI layout.
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
                                     vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType || accType.getRank() != 2)
    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

  // Expect 3D inputs for VNNI packed data.
  VectorType lhsType = contractOp.getLhs().getType();
  VectorType rhsType = contractOp.getRhs().getType();
  if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expects lhs and rhs 3D vectors");

  // Check if shapes are compatible with AMX tile.
  if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
      !verifyAmxShape(accType))
    return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");

  // Validate affine maps.
  //
  // Iterators can be ordered arbitrarily. Indexing map positions are based on
  // operands' target shapes.
  // The matrix layouts must match the following:
  // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
  // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
  // - matrix C - [M]x[N]
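  //
  // For example, with iterator order (m, n, k, vnni), the following maps
  // satisfy these constraints (the dimension order here is illustrative;
  // any permutation of the iterators is accepted):
  //   A: affine_map<(m, n, k, vnni) -> (m, k, vnni)>
  //   B: affine_map<(m, n, k, vnni) -> (k, n, vnni)>
  //   C: affine_map<(m, n, k, vnni) -> (m, n)>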
  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];
  if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
      mapB.getNumResults() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid input indexing maps");
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return rewriter.notifyMatchFailure(contractOp,
                                       "Failed to infer contraction dims");
  // Two reduction dimensions are expected:
  // - one for the K dimension
  // - one for the VNNI factor
  if (dims->k.size() != 2)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expected two reduction dims");
  assert(dims->m.size() == 1 && dims->n.size() == 1 &&
         "Invalid parallel contraction dims");

  SmallVector<vector::IteratorType> iteratorTypes =
      contractOp.getIteratorTypesArray();
  // Check VNNI dim maps - the innermost dim for A and B inputs.
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
  // Check K dim maps - non-transposed row-major layout.
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
  // Check M and N dim maps - map to non-transposed output.
  AffineMap mapC = indexingMaps[2];
  auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
  auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
  if (!mDimC || !nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
  auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
  if (!parallelDimA ||
      iteratorTypes[parallelDimA.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimA != mDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
  if (!parallelDimB ||
      iteratorTypes[parallelDimB.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimB != nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");

  return success();
}

/// Validate contraction operands for AMX lowering.
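/// The accepted element type combinations mirror the AMX tile compute ops:
/// i8 x i8 with i32 accumulation, and f16 x f16 or bf16 x bf16 with f32
/// accumulation.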
static LogicalResult validateOperands(PatternRewriter &rewriter,
                                      vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType)
    return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");

  // Check if operand types are compatible with AMX compute ops.
  bool validElemTypes = false;
  Type lhsElemType = contractOp.getLhs().getType().getElementType();
  Type rhsElemType = contractOp.getRhs().getType().getElementType();
  Type accElemType = accType.getElementType();
  if (accElemType.isInteger(32)) {
    validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
  } else if (accElemType.isF32()) {
    validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
                     (lhsElemType.isBF16() && rhsElemType.isBF16());
  }
  if (!validElemTypes)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid combination of operand types");

  if (failed(isAmxVnniLayout(rewriter, contractOp)))
    return failure();

  return success();
}

/// Collapse the two innermost dimensions together.
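/// For example, a hypothetical memref<16x8x2xbf16> collapses with
/// reassociation [[0], [1, 2]] into memref<16x16xbf16>.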
static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
                                              TypedValue<MemRefType> memref) {
  int64_t rank = memref.getType().getRank();
  SmallVector<ReassociationIndices> reassocIndices;
  for (auto i : llvm::seq<int64_t>(0, rank - 2))
    reassocIndices.push_back({i});
  reassocIndices.push_back({rank - 2, rank - 1});
  return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
                                         reassocIndices);
}

/// Attempt to create an AMX tile load/store operation equivalent to the given
/// vector transfer `xfer` op.
/// This allows skipping the longer route through registers and a temporary
/// buffer otherwise required to move data to/from an AMX tile.
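///
/// For example (a rough sketch; shapes and value names are hypothetical),
/// a plain 2D transfer read such as
///   %v = vector.transfer_read %buf[%i, %j] ...
///       : memref<64x64xf32>, vector<16x16xf32>
/// may be rewritten into a direct tile load from a subview of %buf, roughly
///   %sub = memref.subview %buf[%i, %j] [16, 16] [1, 1]
///   %t = amx.tile_load %sub[%c0, %c0] : ... into !amx.tile<16x16xf32>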
static Operation *
loadStoreFromTransfer(PatternRewriter &rewriter,
                      VectorTransferOpInterface xferOp, bool isPacked,
                      TypedValue<amx::TileType> tileToStore = nullptr) {
  if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
    return nullptr;
  if (xferOp.hasOutOfBoundsDim() ||
      !xferOp.getPermutationMap().isMinorIdentity())
    return nullptr;

  // Extra checks in case of a write op.
  // Stores must not be packed.
  if (isa<vector::TransferWriteOp>(xferOp) &&
      (!tileToStore || isPacked ||
       tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
    return nullptr;

  // Check for a memref source buffer.
  // AMX data transfer requires at least 2D shape to correctly
  // infer stride between rows.
  Value base = xferOp.getBase();
  auto memTy = dyn_cast<MemRefType>(base.getType());
  // Check for a memref before querying its rank.
  if (!memTy || memTy.getRank() < 2)
    return nullptr;
  int64_t memRank = memTy.getRank();

  // Check that the source buffer has enough contiguous elements to load a
  // whole AMX tile row.
  //
  // To ensure correctness, the validation is conservative and expects the
  // buffer's innermost dimensions to be statically known, equal to or larger
  // than the vector row length, and equal to the VNNI dimension if applicable.
  //
  // This check could be relaxed to accept more arbitrarily shaped buffers as
  // long as there are enough contiguous elements to load a whole row.
  if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
    return nullptr;
  VectorType vecTy = xferOp.getVectorType();
  ArrayRef<int64_t> vecShape = vecTy.getShape();
  ArrayRef<int64_t> memShape = memTy.getShape();
  if (memShape.back() == ShapedType::kDynamic ||
      memShape.back() < vecShape.back())
    return nullptr;
  if (isPacked &&
      (memShape.back() != vecShape.back() ||
       memShape[memShape.size() - 2] == ShapedType::kDynamic ||
       memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
    return nullptr;
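  // For instance (illustrative shapes): given a packed vector<16x8x2xbf16>,
  // a memref<32x8x2xbf16> source is accepted, while a buffer with a dynamic
  // or mismatched trailing dimension, e.g. memref<32x?x2xbf16>, is rejected.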

  // Load values directly from the buffer to an AMX tile.
  PatternRewriter::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(xferOp);
  Location loc = xferOp.getLoc();

  // Create a subview of the source buffer based on the transfer op to resolve
  // offsets.
  SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
  int64_t vecRank = vecTy.getRank();
  assert(memRank >= vecRank &&
         "Expects buffer to be the same or greater rank than vector");
  SmallVector<int64_t> shape(memRank - vecRank, 1);
  shape.append(vecShape.begin(), vecShape.end());
  TypedValue<MemRefType> src =
      memref::SubViewOp::create(
          rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
          getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
          .getResult();

  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    src = collapseLastDim(rewriter, src);
  int64_t rows = vecShape[0];
  int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);

  Operation *amxTileOp = nullptr;
  if (isa<vector::TransferReadOp>(xferOp)) {
    amxTileOp =
        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
  } else if (isa<vector::TransferWriteOp>(xferOp)) {
    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
                                         tileToStore);
  } else {
    llvm_unreachable("unsupported vector transfer op");
  }

  return amxTileOp;
}

/// Attempt to create an AMX tile load operation equivalent to the given
/// vector transfer `readOp`.
/// Returns the loaded AMX tile if successful.
static FailureOr<TypedValue<amx::TileType>>
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
                 bool isPacked) {
  amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
      loadStoreFromTransfer(rewriter, readOp, isPacked));
  if (!loadOp)
    return failure();
  return loadOp.getRes();
}

/// Attempt to create an AMX tile store operation equivalent to the given
/// vector transfer `writeOp`.
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
                                       vector::TransferWriteOp writeOp,
                                       TypedValue<amx::TileType> tileToStore) {
  return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
                                       tileToStore) != nullptr);
}

/// Load vector values to an AMX tile.
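/// If the vector is produced by a compatible vector.transfer_read, the tile
/// is loaded directly from the read's source buffer; otherwise, the data is
/// staged through a temporary buffer.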
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
                                          TypedValue<VectorType> vec) {
  Location loc = vec.getLoc();

  VectorType vecTy = vec.getType();
  bool isPacked = vecTy.getRank() == 3;

  // Try to load the tile directly from the vector producer's buffer.
  auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
  FailureOr<TypedValue<amx::TileType>> tile =
      loadFromTransfer(rewriter, readOp, isPacked);
  if (succeeded(tile))
    return *tile;

  // Transfer the vector to a tile through an intermediate buffer.
  Value buf = memref::AllocaOp::create(
      rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
  vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);

  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));

  ArrayRef<int64_t> shape = vecTy.getShape();
  int64_t rows = shape[0];
  int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
                                 std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
                                 {zeroIndex, zeroIndex});
}

/// Store an AMX tile in a vector.
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
                                        TypedValue<amx::TileType> tile) {
  Location loc = tile.getLoc();

  // Transfer the tile to a vector through an intermediate buffer.
  amx::TileType tileTy = tile.getType();
  Value buf = memref::AllocaOp::create(
      rewriter, loc,
      MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(2, zeroIndex);
  amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);

  auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
  return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
}

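/// Convert a VNNI-packed vector.contract into a sequence of AMX tile ops.
///
/// A rough sketch of the target IR for a bf16 matmul with f32 accumulation
/// (shapes and value names are illustrative):
///   %lhs = amx.tile_load ... : ... into !amx.tile<16x32xbf16>
///   %rhs = amx.tile_load ... : ... into !amx.tile<16x32xbf16>
///   %acc = amx.tile_load ... : ... into !amx.tile<16x16xf32>
///   %res = amx.tile_mulf %lhs, %rhs, %acc : !amx.tile<16x32xbf16>,
///          !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
///   amx.tile_store ..., %res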
struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");
    if (failed(validateOperands(rewriter, contractOp)))
      return failure();

    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
    auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
    assert(acc && "Invalid accumulator type");
    TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);

    TypedValue<amx::TileType> tileMul;
    if (acc.getType().getElementType().isFloat()) {
      tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    } else {
      tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    }

    // If the contraction result is only written back to memory, try to
    // replace the vector op with an AMX store directly.
    Value res = contractOp.getResult();
    if (res.hasOneUse()) {
      auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
      LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
      if (succeeded(storeRes)) {
        rewriter.eraseOp(writeOp);
        rewriter.eraseOp(contractOp);
        return success();
      }
    }

    // Load the result back into a vector.
    Value newResult = storeTile(rewriter, tileMul);
    rewriter.replaceOp(contractOp, newResult);

    return success();
  }
};

struct ConvertVectorToAMXPass
    : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    RewritePatternSet patterns(&ctx);
    populateVectorToAMXConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
  patterns.add<ContractionToAMX>(patterns.getContext());
}