23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/DebugLog.h"
29#define GEN_PASS_DEF_CONVERTVECTORTOAMX
30#include "mlir/Conversion/Passes.h.inc"
35#define DEBUG_TYPE "vector-to-amx"
41static bool verifyAmxShape(VectorType vec) {
45 if (vec.getRank() != 2 && vec.getRank() != 3)
51 unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
54 if (vec.getRank() == 3) {
55 int64_t vnniFactor = 32 / elemBitWidth;
56 if (
shape.back() != vnniFactor) {
57 LDBG() <<
"invalid VNNI packing factor";
64 constexpr unsigned maxRows = 16;
65 constexpr unsigned maxBitsPerRow = 64 * 8;
66 return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
71 vector::ContractionOp contractOp) {
72 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
73 if (!accType || accType.getRank() != 2)
77 VectorType lhsType = contractOp.getLhs().getType();
78 VectorType rhsType = contractOp.getRhs().getType();
79 if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
81 "Expects lhs and rhs 3D vectors");
84 if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
85 !verifyAmxShape(accType))
102 "Invalid input indexing maps");
103 FailureOr<linalg::ContractionDimensions> dims =
107 "Failed to infer contraction dims");
111 if (dims->k.size() != 2)
113 "Expected two reduction dims");
114 assert(dims->m.size() == 1 && dims->n.size() == 1 &&
115 "Invalid parallel contraction dims");
118 contractOp.getIteratorTypesArray();
120 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(2));
121 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(2));
122 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
123 iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
126 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(1));
127 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(0));
128 if (!redDimA || !redDimB || redDimA != redDimB ||
129 iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
133 auto mDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(0));
134 auto nDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(1));
135 if (!mDimC || !nDimC)
137 auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(0));
139 iteratorTypes[parallelDimA.getPosition()] !=
140 vector::IteratorType::parallel ||
141 parallelDimA != mDimC)
143 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(1));
145 iteratorTypes[parallelDimB.getPosition()] !=
146 vector::IteratorType::parallel ||
147 parallelDimB != nDimC)
155 vector::ContractionOp contractOp) {
156 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
161 bool validElemTypes =
false;
162 Type lhsElemType = contractOp.getLhs().getType().getElementType();
163 Type rhsElemType = contractOp.getRhs().getType().getElementType();
164 Type accElemType = accType.getElementType();
167 }
else if (accElemType.
isF32()) {
168 validElemTypes = (lhsElemType.
isF16() && rhsElemType.
isF16()) ||
173 "Invalid combination of operand types");
175 if (failed(isAmxVnniLayout(rewriter, contractOp)))
186 for (
auto i : llvm::seq<int64_t>(0, rank - 2))
187 reassocIndices.push_back({i});
188 reassocIndices.push_back({rank - 2, rank - 1});
189 return memref::CollapseShapeOp::create(rewriter,
memref.getLoc(),
memref,
199 VectorTransferOpInterface xferOp,
bool isPacked,
201 if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
203 if (xferOp.hasOutOfBoundsDim() ||
204 !xferOp.getPermutationMap().isMinorIdentity())
209 if (isa<vector::TransferWriteOp>(xferOp) &&
210 (!tileToStore || isPacked ||
211 tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
217 Value base = xferOp.getBase();
218 auto memTy = dyn_cast<MemRefType>(base.
getType());
219 int64_t memRank = memTy.getRank();
220 if (!memTy || memRank < 2)
232 if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
234 VectorType vecTy = xferOp.getVectorType();
237 if (memShape.back() == ShapedType::kDynamic ||
238 memShape.back() < vecShape.back())
241 (memShape.back() != vecShape.back() ||
242 memShape[memShape.size() - 2] == ShapedType::kDynamic ||
243 memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
254 int64_t vecRank = vecTy.getRank();
255 assert(memRank >= vecRank &&
256 "Expects buffer to be the same or greater rank than vector");
258 shape.append(vecShape.begin(), vecShape.end());
260 memref::SubViewOp::create(
267 src = collapseLastDim(rewriter, src);
269 int64_t cols = llvm::product_of(vecShape.drop_front());
270 auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
276 if (isa<vector::TransferReadOp>(xferOp)) {
277 amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src,
279 }
else if (isa<vector::TransferWriteOp>(xferOp)) {
280 amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
283 llvm_unreachable(
"unsupported vector transfer op");
292static FailureOr<TypedValue<x86::amx::TileType>>
293loadFromTransfer(
PatternRewriter &rewriter, vector::TransferReadOp readOp,
295 x86::amx::TileLoadOp loadOp = dyn_cast_if_present<x86::amx::TileLoadOp>(
296 loadStoreFromTransfer(rewriter, readOp, isPacked));
299 return loadOp.getRes();
305storeFromTransfer(
PatternRewriter &rewriter, vector::TransferWriteOp writeOp,
307 return success(loadStoreFromTransfer(rewriter, writeOp,
false,
316 VectorType vecTy = vec.getType();
317 bool isPacked = vecTy.getRank() == 3;
320 auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
321 FailureOr<TypedValue<x86::amx::TileType>>
tile =
322 loadFromTransfer(rewriter, readOp, isPacked);
327 Value buf = memref::AllocaOp::create(
328 rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
331 vector::TransferWriteOp::create(rewriter, loc, vec, buf,
indices);
340 auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType());
342 return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf,
343 {zeroIndex, zeroIndex});
353 Value buf = memref::AllocaOp::create(
355 MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
358 x86::amx::TileStoreOp::create(rewriter, loc, buf,
indices,
tile);
360 auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
361 return vector::TransferReadOp::create(rewriter, loc, vecTy, buf,
indices, {});
367 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
371 if (contractOp.getKind() != vector::CombiningKind::ADD)
373 "Expects add combining kind");
374 if (failed(validateOperands(rewriter, contractOp)))
378 loadTile(rewriter, contractOp.getLhs());
380 loadTile(rewriter, contractOp.getRhs());
381 auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
382 assert(
acc &&
"Invalid accumulator type");
386 if (
acc.getType().getElementType().isFloat()) {
387 tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
388 lhsTile, rhsTile, accTile);
390 tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
391 lhsTile, rhsTile, accTile);
396 Value res = contractOp.getResult();
398 auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.
getUsers().begin());
399 LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
400 if (succeeded(storeRes)) {
408 Value newResult = storeTile(rewriter, tileMul);
409 rewriter.
replaceOp(contractOp, newResult);
415struct ConvertVectorToAMXPass
417 void runOnOperation()
override {
422 return signalPassFailure();
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Specialization of arith.constant op that returns an integer of index type.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
mlir::x86::AMXTileType TileType
Include the generated interface declarations.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to X86 AMX ops.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...