MLIR 22.0.0git
VectorToAMX.cpp
Go to the documentation of this file.
1//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Builders.h"
20#include "mlir/Pass/Pass.h"
22
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
25
26#include <numeric>
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTVECTORTOAMX
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35#define DEBUG_TYPE "vector-to-amx"
36
37namespace {
38
39/// Return true if vector shape is compatible with AMX tiles.
40/// The validation accounts for VNNI packing.
41static bool verifyAmxShape(VectorType vec) {
42 // Check overall shape:
43 // - 2D for plain layout input or output
44 // - 3D for VNNI packed input
45 if (vec.getRank() != 2 && vec.getRank() != 3)
46 return false;
47
48 ArrayRef<int64_t> shape = vec.getShape();
49 int64_t rows = shape[0];
50 int64_t cols = shape[1];
51 unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
52
53 // 3D shape indicates VNNI packed layout.
54 if (vec.getRank() == 3) {
55 int64_t vnniFactor = 32 / elemBitWidth;
56 if (shape.back() != vnniFactor) {
57 LDBG() << "invalid VNNI packing factor";
58 return false;
59 }
60 cols *= vnniFactor;
61 }
62
63 // AMX tile supports up to 16 rows of 64 bytes each.
64 constexpr unsigned maxRows = 16;
65 constexpr unsigned maxBitsPerRow = 64 * 8;
66 return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
67}
68
69/// Check if contraction operands are in AMX-compatible packed VNNI layout.
70static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
71 vector::ContractionOp contractOp) {
72 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
73 if (!accType || accType.getRank() != 2)
74 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
75
76 // Expect 3D inputs for VNNI packed data.
77 VectorType lhsType = contractOp.getLhs().getType();
78 VectorType rhsType = contractOp.getRhs().getType();
79 if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
80 return rewriter.notifyMatchFailure(contractOp,
81 "Expects lhs and rhs 3D vectors");
82
83 // Check if shapes are compatible with AMX tile.
84 if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
85 !verifyAmxShape(accType))
86 return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
87
88 // Validate affine maps.
89 //
90 // Iterators can be ordered arbitrarily. Indexing map positions are based on
91 // operands' target shapes.
92 // The matrix layouts must match the following:
93 // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
94 // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
95 // - matrix C - [M]x[N]
96 SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
97 AffineMap mapA = indexingMaps[0];
98 AffineMap mapB = indexingMaps[1];
99 if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
100 mapB.getNumResults() != 3)
101 return rewriter.notifyMatchFailure(contractOp,
102 "Invalid input indexing maps");
103 FailureOr<linalg::ContractionDimensions> dims =
104 linalg::inferContractionDims(indexingMaps);
105 if (failed(dims))
106 return rewriter.notifyMatchFailure(contractOp,
107 "Failed to infer contraction dims");
108 // Two reduction dimensions are expected:
109 // - one for the K dimension
110 // - one for the VNNI factor
111 if (dims->k.size() != 2)
112 return rewriter.notifyMatchFailure(contractOp,
113 "Expected two reduction dims");
114 assert(dims->m.size() == 1 && dims->n.size() == 1 &&
115 "Invalid parallel contraction dims");
116
118 contractOp.getIteratorTypesArray();
119 // Check VNNI dim maps - the innermost dim for A and B inputs.
120 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
121 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
122 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
123 iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
124 return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
125 // Check K dim maps - non-transposed row-major layout.
126 auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
127 auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
128 if (!redDimA || !redDimB || redDimA != redDimB ||
129 iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
130 return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
131 // Check M and N dim maps - map to non-transposed output.
132 AffineMap mapC = indexingMaps[2];
133 auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
134 auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
135 if (!mDimC || !nDimC)
136 return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
137 auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
138 if (!parallelDimA ||
139 iteratorTypes[parallelDimA.getPosition()] !=
140 vector::IteratorType::parallel ||
141 parallelDimA != mDimC)
142 return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
143 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
144 if (!parallelDimB ||
145 iteratorTypes[parallelDimB.getPosition()] !=
146 vector::IteratorType::parallel ||
147 parallelDimB != nDimC)
148 return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
149
150 return success();
151}
152
153/// Validate contraction operands for AMX lowering.
154static LogicalResult validateOperands(PatternRewriter &rewriter,
155 vector::ContractionOp contractOp) {
156 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
157 if (!accType)
158 return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
159
160 // Check if operand types are compatible with AMX compute ops.
161 bool validElemTypes = false;
162 Type lhsElemType = contractOp.getLhs().getType().getElementType();
163 Type rhsElemType = contractOp.getRhs().getType().getElementType();
164 Type accElemType = accType.getElementType();
165 if (accElemType.isInteger(32)) {
166 validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
167 } else if (accElemType.isF32()) {
168 validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
169 (lhsElemType.isBF16() && rhsElemType.isBF16());
170 }
171 if (!validElemTypes)
172 return rewriter.notifyMatchFailure(contractOp,
173 "Invalid combination of operand types");
174
175 if (failed(isAmxVnniLayout(rewriter, contractOp)))
176 return failure();
177
178 return success();
179}
180
181/// Collapse the two innermost dimensions together.
182static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
184 int64_t rank = memref.getType().getRank();
186 for (auto i : llvm::seq<int64_t>(0, rank - 2))
187 reassocIndices.push_back({i});
188 reassocIndices.push_back({rank - 2, rank - 1});
189 return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
190 reassocIndices);
191}
192
193/// Attempt to create an AMX tile load/store operation equivalent to the given
194/// vector transfer `xfer` op.
195/// This approach allows to skip longer route through registers and a temporary
196/// buffer otherwise required to move data to/from an AMX tile.
197static Operation *
198loadStoreFromTransfer(PatternRewriter &rewriter,
199 VectorTransferOpInterface xferOp, bool isPacked,
200 TypedValue<amx::TileType> tileToStore = nullptr) {
201 if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
202 return nullptr;
203 if (xferOp.hasOutOfBoundsDim() ||
204 !xferOp.getPermutationMap().isMinorIdentity())
205 return nullptr;
206
207 // Extra checks in case of a write op.
208 // Stores must not be packed.
209 if (isa<vector::TransferWriteOp>(xferOp) &&
210 (!tileToStore || isPacked ||
211 tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
212 return nullptr;
213
214 // Check for a memref source buffer.
215 // AMX data transfer requires at least 2D shape to correctly
216 // infer stride between rows.
217 Value base = xferOp.getBase();
218 auto memTy = dyn_cast<MemRefType>(base.getType());
219 int64_t memRank = memTy.getRank();
220 if (!memTy || memRank < 2)
221 return nullptr;
222
223 // Check that the source buffer has enough contiguous elements to load whole
224 // AMX tile row.
225 //
226 // To ensure correctness, the validation is conservative and expects the
227 // buffer's innermost dimensions to be statically known, equal to or larger
228 // than the vector row length, and equal to the VNNI dimension if applicable.
229 //
230 // This check could be relaxed to accept more arbitrarily shaped buffers as
231 // long as there are enough contiguous elements to load a whole row.
232 if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
233 return nullptr;
234 VectorType vecTy = xferOp.getVectorType();
235 ArrayRef<int64_t> vecShape = vecTy.getShape();
236 ArrayRef<int64_t> memShape = memTy.getShape();
237 if (memShape.back() == ShapedType::kDynamic ||
238 memShape.back() < vecShape.back())
239 return nullptr;
240 if (isPacked &&
241 (memShape.back() != vecShape.back() ||
242 memShape[memShape.size() - 2] == ShapedType::kDynamic ||
243 memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
244 return nullptr;
245
246 // Load values directly from the buffer to an AMX tile.
248 rewriter.setInsertionPoint(xferOp);
249 Location loc = xferOp.getLoc();
250
251 // Create a subview of the source buffer based on the transfer op to resolve
252 // offsets.
253 SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
254 int64_t vecRank = vecTy.getRank();
255 assert(memRank >= vecRank &&
256 "Expects buffer to be the same or greater rank than vector");
257 SmallVector<int64_t> shape(memRank - vecRank, 1);
258 shape.append(vecShape.begin(), vecShape.end());
260 memref::SubViewOp::create(
261 rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
262 getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
263 .getResult();
264
265 // Collapse the VNNI dimension in case of packing.
266 if (isPacked)
267 src = collapseLastDim(rewriter, src);
268 int64_t rows = vecShape[0];
269 int64_t cols = llvm::product_of(vecShape.drop_front());
270 auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
271
272 Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
273 SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
274
275 Operation *amxTileOp = nullptr;
276 if (isa<vector::TransferReadOp>(xferOp)) {
277 amxTileOp =
278 amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
279 } else if (isa<vector::TransferWriteOp>(xferOp)) {
280 amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
281 tileToStore);
282 } else {
283 llvm_unreachable("unsupported vector transfer op");
284 }
285
286 return amxTileOp;
287}
288
289/// Attempt to create an AMX tile load operation equivalent to the given
290/// vector transfer `readOp`.
291/// Returns loaded AMX tile if successful.
292static FailureOr<TypedValue<amx::TileType>>
293loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
294 bool isPacked) {
295 amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
296 loadStoreFromTransfer(rewriter, readOp, isPacked));
297 if (!loadOp)
298 return failure();
299 return loadOp.getRes();
300}
301
302/// Attempt to create an AMX tile store operation equivalent to the given
303/// vector transfer `writeOp`.
304static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
305 vector::TransferWriteOp writeOp,
306 TypedValue<amx::TileType> tileToStore) {
307 return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
308 tileToStore));
309}
310
311/// Load vector values to an AMX tile.
312static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
314 Location loc = vec.getLoc();
315
316 VectorType vecTy = vec.getType();
317 bool isPacked = vecTy.getRank() == 3;
318
319 // Try to load tile directly from vector producer's buffer.
320 auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
321 FailureOr<TypedValue<amx::TileType>> tile =
322 loadFromTransfer(rewriter, readOp, isPacked);
323 if (succeeded(tile))
324 return *tile;
325
326 // Transfer the vector to a tile through an intermediate buffer.
327 Value buf = memref::AllocaOp::create(
328 rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
329 Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
330 SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
331 vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
332
333 // Collapse the VNNI dimension in case of packing.
334 if (isPacked)
335 buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
336
337 ArrayRef<int64_t> shape = vecTy.getShape();
338 int64_t rows = shape[0];
339 int64_t cols = llvm::product_of(shape.drop_front());
340 auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
341
342 return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
343 {zeroIndex, zeroIndex});
344}
345
346/// Store an AMX tile in a vector.
347static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
349 Location loc = tile.getLoc();
350
351 // Transfer the tile to a vector through an intermediate buffer.
352 amx::TileType tileTy = tile.getType();
353 Value buf = memref::AllocaOp::create(
354 rewriter, loc,
355 MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
356 Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
357 SmallVector<Value> indices(2, zeroIndex);
358 amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
359
360 auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
361 return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
362}
363
364struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
365 using Base::Base;
366
367 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
368 PatternRewriter &rewriter) const override {
369 Location loc = contractOp.getLoc();
370
371 if (contractOp.getKind() != vector::CombiningKind::ADD)
372 return rewriter.notifyMatchFailure(contractOp,
373 "Expects add combining kind");
374 if (failed(validateOperands(rewriter, contractOp)))
375 return failure();
376
377 TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
378 TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
379 auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
380 assert(acc && "Invalid accumulator type");
381 TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
382
384 if (acc.getType().getElementType().isFloat()) {
385 tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
386 lhsTile, rhsTile, accTile);
387 } else {
388 tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
389 lhsTile, rhsTile, accTile);
390 }
391
392 // If the contraction result is only written back to memory, try to replace
393 // the vector op with an AMX store directly.
394 Value res = contractOp.getResult();
395 if (res.hasOneUse()) {
396 auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
397 LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
398 if (succeeded(storeRes)) {
399 rewriter.eraseOp(writeOp);
400 rewriter.eraseOp(contractOp);
401 return success();
402 }
403 }
404
405 // Load the result back into a vector.
406 Value newResult = storeTile(rewriter, tileMul);
407 rewriter.replaceOp(contractOp, newResult);
408
409 return success();
410 }
411};
412
413struct ConvertVectorToAMXPass
414 : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
415 void runOnOperation() override {
416 MLIRContext &ctx = getContext();
419 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
420 return signalPassFailure();
421 }
422};
423
424} // namespace
425
427 patterns.add<ContractionToAMX>(patterns.getContext());
428}
return success()
b getContext())
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF16() const
Definition Types.cpp:38
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Include the generated interface declarations.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to AMX ops.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1293
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...