23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/DebugLog.h"
29 #define GEN_PASS_DEF_CONVERTVECTORTOAMX
30 #include "mlir/Conversion/Passes.h.inc"
35 #define DEBUG_TYPE "vector-to-amx"
// Returns true if the vector shape is compatible with an AMX tile.
// NOTE(review): this source view is partially elided (original line numbers
// jump 41→45→49…); the declaration of `shape` (presumably vec.getShape())
// and the function's closing lines are not visible here.
41 static bool verifyAmxShape(VectorType vec) {
// Only plain 2-D tiles or 3-D VNNI-packed tiles are representable.
45 if (vec.getRank() != 2 && vec.getRank() != 3)
49 int64_t
rows = shape[0];
50 int64_t
cols = shape[1];
51 unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
// 3-D case: the innermost dim must equal the VNNI packing factor,
// i.e. how many narrow elements pack into a 32-bit lane.
54 if (vec.getRank() == 3) {
55 int64_t vnniFactor = 32 / elemBitWidth;
56 if (shape.back() != vnniFactor) {
57 LDBG() <<
"invalid VNNI packing factor";
// Hardware tile limits: at most 16 rows and 64 bytes (64 * 8 bits) per row.
64 constexpr
unsigned maxRows = 16;
65 constexpr
unsigned maxBitsPerRow = 64 * 8;
66 return rows <= maxRows && (
cols * elemBitWidth) <= maxBitsPerRow;
// Checks that the contraction's operands are laid out in AMX-compatible
// VNNI form. NOTE(review): the function header is elided in this view;
// the name is inferred from the call `isAmxVnniLayout(rewriter, contractOp)`
// visible later in this file. `mapA`/`mapB`/`mapC` and several early-return
// bodies (notifyMatchFailure calls) are also elided.
71 vector::ContractionOp contractOp) {
// Accumulator must be a plain 2-D vector.
72 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
73 if (!accType || accType.getRank() != 2)
// lhs/rhs must be 3-D, i.e. VNNI-packed (2-D data + packing dim).
77 VectorType lhsType = contractOp.getLhs().getType();
78 VectorType rhsType = contractOp.getRhs().getType();
79 if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
81 "Expects lhs and rhs 3D vectors");
// All operands must also fit within AMX tile limits.
84 if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
85 !verifyAmxShape(accType))
102 "Invalid input indexing maps");
// Use linalg's contraction-dim inference to classify dims as m/n/k.
103 FailureOr<linalg::ContractionDimensions> dims =
107 "Failed to infer contraction dims");
// VNNI layout implies exactly two reduction dims: the K dim plus the
// packing dim.
111 if (dims->k.size() != 2)
113 "Expected two reduction dims");
114 assert(dims->m.size() == 1 && dims->n.size() == 1 &&
115 "Invalid parallel contraction dims");
118 contractOp.getIteratorTypesArray();
// The innermost (VNNI) dim must be the same reduction dim on both A and B.
120 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(2));
121 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(2));
122 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
123 iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
// The K dim: A's middle result must match B's outer result, and be a
// reduction iterator.
126 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(1));
127 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(0));
128 if (!redDimA || !redDimB || redDimA != redDimB ||
129 iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
// C's results give the M and N dims.
133 auto mDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(0));
134 auto nDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(1));
135 if (!mDimC || !nDimC)
// A's outer dim must be the parallel M dim matching C.
137 auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(0));
139 iteratorTypes[parallelDimA.getPosition()] !=
140 vector::IteratorType::parallel ||
141 parallelDimA != mDimC)
// B's middle dim must be the parallel N dim matching C.
143 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(1));
145 iteratorTypes[parallelDimB.getPosition()] !=
146 vector::IteratorType::parallel ||
147 parallelDimB != nDimC)
// Validates that operand element types form an AMX-supported combination
// and that the layout is VNNI-compatible. NOTE(review): function header is
// elided; name inferred from the call `validateOperands(rewriter, contractOp)`
// visible later in this file. The integer-accumulator branch (and the
// condition it checks) is also elided between the lines shown.
155 vector::ContractionOp contractOp) {
156 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
161 bool validElemTypes =
false;
162 Type lhsElemType = contractOp.getLhs().getType().getElementType();
163 Type rhsElemType = contractOp.getRhs().getType().getElementType();
164 Type accElemType = accType.getElementType();
167 }
// f32 accumulation accepts f16 operands (another combination, presumably
// bf16, appears to be elided after the `||`).
else if (accElemType.
isF32()) {
168 validElemTypes = (lhsElemType.
isF16() && rhsElemType.
isF16()) ||
173 "Invalid combination of operand types");
// Finally ensure the indexing maps describe a VNNI-packed contraction.
175 if (
failed(isAmxVnniLayout(rewriter, contractOp)))
// Collapses the last two dims of `memref` into one via memref.collapse_shape,
// keeping all leading dims as-is. NOTE(review): the signature line is elided
// in this view; name inferred from the call `collapseLastDim(rewriter, src)`
// later in the file.
184 int64_t rank = memref.getType().getRank();
// Reassociation: [0], [1], …, [rank-3], then [rank-2, rank-1] merged.
186 for (
auto i : llvm::seq<int64_t>(0, rank - 2))
187 reassocIndices.push_back({i});
188 reassocIndices.push_back({rank - 2, rank - 1});
189 return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
// Attempts to replace a vector.transfer_read/write with an AMX tile
// load/store directly on the underlying memref. NOTE(review): header and
// several statements are elided in this view (the return-failure bodies,
// `memShape`/`vecShape`/`shape` declarations, the subview operands, and
// `tileType`/`tileIndicides` definitions are not visible).
199 VectorTransferOpInterface xferOp,
bool isPacked,
// Only in-bounds, minor-identity transfer_read/write ops qualify.
201 if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
203 if (xferOp.hasOutOfBoundsDim() ||
204 !xferOp.getPermutationMap().isMinorIdentity())
// A store additionally needs a tile whose shape matches the written vector,
// and packed (3-D) results cannot be stored directly.
209 if (isa<vector::TransferWriteOp>(xferOp) &&
210 (!tileToStore || isPacked ||
211 tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
217 Value base = xferOp.getBase();
218 auto memTy = dyn_cast<MemRefType>(base.
getType());
// NOTE(review): memTy.getRank() is invoked BEFORE the `!memTy` null check on
// the next line — if the dyn_cast fails this is a null dereference. The
// getRank() call should be folded into (or moved after) the check; confirm
// against the upstream version.
219 int64_t memRank = memTy.getRank();
220 if (!memTy || memRank < 2)
// The trailing dims touched by the transfer must be contiguous in memory
// (two dims when VNNI-packed, one otherwise).
232 if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
234 VectorType vecTy = xferOp.getVectorType();
// Static-shape checks: the buffer's trailing dims must be static and large
// enough to hold the vector being transferred.
237 if (memShape.back() == ShapedType::kDynamic ||
238 memShape.back() < vecShape.back())
241 (memShape.back() != vecShape.back() ||
242 memShape[memShape.size() - 2] == ShapedType::kDynamic ||
243 memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
247 PatternRewriter::InsertionGuard g(rewriter);
254 int64_t vecRank = vecTy.getRank();
255 assert(memRank >= vecRank &&
256 "Expects buffer to be the same or greater rank than vector");
258 shape.append(vecShape.begin(), vecShape.end());
// Take a subview of the accessed region, then collapse the packed dims so
// the tile op sees a 2-D buffer.
260 memref::SubViewOp::create(
267 src = collapseLastDim(rewriter, src);
268 int64_t
rows = vecShape[0];
// For packed vectors, columns = product of all trailing dims.
269 int64_t
cols = llvm::product_of(vecShape.drop_front());
// NOTE(review): identifier `tileIndicides` looks like a typo for
// `tileIndices` — worth renaming upstream (doc-only pass keeps it as-is).
276 if (isa<vector::TransferReadOp>(xferOp)) {
278 amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
279 }
else if (isa<vector::TransferWriteOp>(xferOp)) {
280 amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
// Guarded above: only read/write ops reach this point.
283 llvm_unreachable(
"unsupported vector transfer op");
// Convenience wrapper: lowers a vector.transfer_read to an amx.tile_load via
// loadStoreFromTransfer and returns the resulting tile value, or failure.
// NOTE(review): the failure-return line between 296 and 299 is elided.
292 static FailureOr<TypedValue<amx::TileType>>
293 loadFromTransfer(
PatternRewriter &rewriter, vector::TransferReadOp readOp,
295 amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
296 loadStoreFromTransfer(rewriter, readOp, isPacked));
299 return loadOp.getRes();
// Convenience wrapper: lowers a vector.transfer_write to an amx.tile_store.
// Succeeds iff loadStoreFromTransfer produced an op. NOTE(review): the
// function header and the trailing argument(s) after `false,` are elided;
// `false` presumably corresponds to the isPacked parameter.
305 vector::TransferWriteOp writeOp,
307 return success(loadStoreFromTransfer(rewriter, writeOp,
false,
// Materializes a vector value as an AMX tile. Fast path: if the vector comes
// from a transfer_read, load the tile straight from memory; otherwise spill
// the vector to a stack buffer and tile-load from there. NOTE(review):
// function header elided; name is a guess from the helpers it pairs with.
316 VectorType vecTy = vec.getType();
// Rank 3 means the vector is VNNI-packed.
317 bool isPacked = vecTy.getRank() == 3;
// Try to fold the producing transfer_read directly into a tile load.
320 auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
321 FailureOr<TypedValue<amx::TileType>>
tile =
322 loadFromTransfer(rewriter, readOp, isPacked);
// Slow path: round-trip through a stack-allocated buffer.
327 Value buf = memref::AllocaOp::create(
328 rewriter, loc,
MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
331 vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
338 int64_t
rows = shape[0];
// Packed trailing dims fold into the tile's column count.
339 int64_t
cols = llvm::product_of(shape.drop_front());
342 return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
343 {zeroIndex, zeroIndex});
// Converts an AMX tile back into a vector value by storing it to a stack
// buffer and reading it out with a transfer_read. NOTE(review): function
// header elided; name is a guess from the call `storeTile(rewriter, tileMul)`
// later in this file. The alloca's type arguments are also elided.
352 amx::TileType tileTy =
tile.getType();
353 Value buf = memref::AllocaOp::create(
358 amx::TileStoreOp::create(rewriter, loc, buf, indices,
tile);
// Read the stored tile back as a vector of identical shape/element type.
360 auto vecTy =
VectorType::get(tileTy.getShape(), tileTy.getElementType());
361 return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
// Pattern entry point: rewrites a vector.contraction into AMX tile ops
// (tile loads + tile_mul[f|i] + tile store). NOTE(review): many lines are
// elided in this view — the rewriter parameter, `loc`, lhsTile/rhsTile/
// accTile materialization, and the erase/early-return bodies are not shown.
367 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
// AMX only accumulates; only ADD-combining contractions qualify.
371 if (contractOp.getKind() != vector::CombiningKind::ADD)
373 "Expects add combining kind");
374 if (
failed(validateOperands(rewriter, contractOp)))
379 auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
380 assert(acc &&
"Invalid accumulator type");
// Pick the float or integer tile-multiply based on accumulator elem type.
384 if (acc.getType().getElementType().isFloat()) {
385 tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
386 lhsTile, rhsTile, accTile);
388 tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
389 lhsTile, rhsTile, accTile);
394 Value res = contractOp.getResult();
// If the sole user is a transfer_write, try storing the tile directly to
// memory; presumably a hasOneUse() guard sits in the elided lines — confirm.
396 auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.
getUsers().begin());
397 LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
398 if (succeeded(storeRes)) {
// Fallback: round-trip the tile through a buffer back into a vector and
// replace the contraction with it.
406 Value newResult = storeTile(rewriter, tileMul);
407 rewriter.
replaceOp(contractOp, newResult);
// Pass driver generated from Passes.td (GEN_PASS_DEF_CONVERTVECTORTOAMX):
// applies the vector→AMX conversion patterns greedily over the operation.
// NOTE(review): the pattern-population and applyPatternsGreedily lines
// (original 416–419) are elided in this view.
413 struct ConvertVectorToAMXPass
414 :
public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
415 void runOnOperation()
override {
// Propagate greedy-rewrite failure as pass failure.
420 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Include the generated interface declarations.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to AMX ops.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...