23 #include "llvm/Support/DebugLog.h"
28 #define GEN_PASS_DEF_CONVERTVECTORTOAMX
29 #include "mlir/Conversion/Passes.h.inc"
34 #define DEBUG_TYPE "vector-to-amx"
// Checks whether `vec` has a shape a single AMX tile can hold: rank 2 (plain)
// or rank 3 (VNNI-packed). For rank-3 vectors the innermost dimension must
// equal the VNNI packing factor (32 / element bit width). The tile may have
// at most 16 rows, and a row must fit in 64 bytes (512 bits).
// NOTE(review): this is a partial extraction — several original source lines
// (e.g. 41-43, 45-47, 51-52, 57-62) are missing between the statements below.
40 static bool verifyAmxShape(VectorType vec) {
44 if (vec.getRank() != 2 && vec.getRank() != 3)
// `shape` is presumably vec.getShape(); its declaration is not visible here.
48 int64_t
rows = shape[0];
49 int64_t
cols = shape[1];
50 unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
// A 3-D vector is VNNI-packed: the innermost dim must carry exactly one
// 32-bit word's worth of elements.
53 if (vec.getRank() == 3) {
54 int64_t vnniFactor = 32 / elemBitWidth;
55 if (shape.back() != vnniFactor) {
56 LDBG() <<
"invalid VNNI packing factor";
// AMX hardware limits: at most 16 rows, at most 64 bytes (512 bits) per row.
63 constexpr
unsigned maxRows = 16;
64 constexpr
unsigned maxBitsPerRow = 64 * 8;
65 return rows <= maxRows && (
cols * elemBitWidth) <= maxBitsPerRow;
// Tail of a contraction-layout check — presumably isAmxVnniLayout, matching
// the call with (rewriter, contractOp) later in this file; TODO confirm
// against the full source. Validates that the vector.contract follows the
// AMX VNNI convention:
//  - acc is a 2-D vector; lhs and rhs are 3-D (VNNI-packed) vectors,
//  - all three operand shapes pass verifyAmxShape,
//  - the innermost (VNNI) dim of A and B is the same reduction dim,
//  - A's result 1 and B's result 0 form the shared K reduction dim,
//  - A's result 0 / B's result 1 are parallel dims equal to C's M / N dims.
// NOTE(review): partial extraction — original lines (73-75, 79, 81-82,
// 85-100, 103-105, 107-109, 111, 115-116, 118, 123-124, 129-131, 135, 137,
// 141, 143) are missing; `mapA`/`mapB`/`mapC` declarations are not visible.
70 vector::ContractionOp contractOp) {
71 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
72 if (!accType || accType.getRank() != 2)
76 VectorType lhsType = contractOp.getLhs().getType();
77 VectorType rhsType = contractOp.getRhs().getType();
78 if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
80 "Expects lhs and rhs 3D vectors");
83 if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
84 !verifyAmxShape(accType))
101 "Invalid input indexing maps");
// Infer M/N/K dims from the contraction's indexing maps; a packed contract
// has two reduction dims (the K dim plus the VNNI packing dim).
102 FailureOr<linalg::ContractionDimensions> dims =
106 "Failed to infer contraction dims");
110 if (dims->k.size() != 2)
112 "Expected two reduction dims");
113 assert(dims->m.size() == 1 && dims->n.size() == 1 &&
114 "Invalid parallel contraction dims");
117 contractOp.getIteratorTypesArray();
// The innermost (VNNI) dim of both A and B must be one shared reduction dim.
119 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(2));
120 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(2));
121 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
122 iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
// A's middle dim and B's outer dim must form the shared K reduction dim.
125 auto redDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(1));
126 auto redDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(0));
127 if (!redDimA || !redDimB || redDimA != redDimB ||
128 iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
// C's two results are the M and N dims the operands must line up with.
132 auto mDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(0));
133 auto nDimC = dyn_cast<AffineDimExpr>(mapC.
getResult(1));
134 if (!mDimC || !nDimC)
// A's outer dim must be the parallel M dim of C.
136 auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.
getResult(0));
138 iteratorTypes[parallelDimA.getPosition()] !=
139 vector::IteratorType::parallel ||
140 parallelDimA != mDimC)
// B's middle dim must be the parallel N dim of C.
142 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.
getResult(1));
144 iteratorTypes[parallelDimB.getPosition()] !=
145 vector::IteratorType::parallel ||
146 parallelDimB != nDimC)
// Fragment of validateOperands(rewriter, contractOp) — name inferred from
// the call site in matchAndRewrite below; TODO confirm. Checks that the
// accumulator is a vector and that the operand element types form an AMX-
// supported combination; the visible branch accepts an f32 accumulator with
// f16 operands (the '||' continuation, presumably bf16, is not visible).
// Finally defers layout validation to isAmxVnniLayout.
// NOTE(review): partial extraction — original lines 156-159, 164-165,
// 168-171, 173, and 175+ are missing.
154 vector::ContractionOp contractOp) {
155 VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
160 bool validElemTypes =
false;
161 Type lhsElemType = contractOp.getLhs().getType().getElementType();
162 Type rhsElemType = contractOp.getRhs().getType().getElementType();
163 Type accElemType = accType.getElementType();
166 }
else if (accElemType.
isF32()) {
167 validElemTypes = (lhsElemType.
isF16() && rhsElemType.
isF16()) ||
172 "Invalid combination of operand types");
174 if (
failed(isAmxVnniLayout(rewriter, contractOp)))
// Body of collapseLastDim (signature not visible in this extract): collapses
// the two trailing dimensions of `memref` into one via memref.collapse_shape,
// keeping every leading dimension in its own reassociation group. Used to
// flatten a VNNI-packed buffer so it can back a 2-D AMX tile access.
183 int64_t rank = memref.getType().getRank();
// Each leading dim maps to itself; the last two dims share one group.
185 for (
auto i : llvm::seq<int64_t>(0, rank - 2))
186 reassocIndices.push_back({i});
187 reassocIndices.push_back({rank - 2, rank - 1});
188 return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
// Fragment of loadStoreFromTransfer (name inferred from the wrapper call
// below; TODO confirm): tries to replace a vector.transfer_read /
// transfer_write with an amx.tile_load / amx.tile_store directly on the
// underlying memref. Bails out unless the transfer is in-bounds with a
// minor-identity permutation map, the base is a rank>=2 memref with
// contiguous trailing dims, and the static shapes admit a tile view. For
// packed (3-D) vectors the buffer's last two dims are collapsed first.
// NOTE(review): partial extraction — many original lines (199, 201, 204-207,
// 211-215, 220-230, 234-235, 238-239, 243-245, 247-252, 256, 258, 260-265,
// 270-275, 277, 281-282) are missing; `tileIndicides`, `tileType`, `shape`,
// `memShape`, `vecShape`, and `src` declarations are not visible here.
198 VectorTransferOpInterface xferOp,
bool isPacked,
200 if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
202 if (xferOp.hasOutOfBoundsDim() ||
203 !xferOp.getPermutationMap().isMinorIdentity())
// Stores must have an unpacked tile whose shape matches the written vector.
208 if (isa<vector::TransferWriteOp>(xferOp) &&
209 (!tileToStore || isPacked ||
210 tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
216 Value base = xferOp.getBase();
217 auto memTy = dyn_cast<MemRefType>(base.
getType());
// NOTE(review): memTy.getRank() is called before the !memTy null check on
// the next line — potential null deref when base is not a memref; confirm
// ordering against the full source before relying on this fragment.
218 int64_t memRank = memTy.getRank();
219 if (!memTy || memRank < 2)
// Packed loads need the last two dims contiguous (they get collapsed below).
231 if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
233 VectorType vecTy = xferOp.getVectorType();
// Static-shape checks: the buffer must be at least as large as the vector
// in the trailing dims, with no dynamic sizes where the tile view needs them.
236 if (memShape.back() == ShapedType::kDynamic ||
237 memShape.back() < vecShape.back())
240 (memShape.back() != vecShape.back() ||
241 memShape[memShape.size() - 2] == ShapedType::kDynamic ||
242 memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
246 PatternRewriter::InsertionGuard g(rewriter);
253 int64_t vecRank = vecTy.getRank();
254 assert(memRank >= vecRank &&
255 "Expects buffer to be the same or greater rank than vector");
257 shape.append(vecShape.begin(), vecShape.end());
// Take a subview of the accessed region, then (if packed) flatten the VNNI
// dim into the column dim so the tile sees a 2-D buffer.
259 memref::SubViewOp::create(
266 src = collapseLastDim(rewriter, src);
267 int64_t
rows = vecShape[0];
268 int64_t
cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
269 std::multiplies<int64_t>());
// Emit the tile op matching the transfer direction.
276 if (isa<vector::TransferReadOp>(xferOp)) {
278 amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
279 }
else if (isa<vector::TransferWriteOp>(xferOp)) {
280 amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
283 llvm_unreachable(
"unsupported vector transfer op");
// Wraps loadStoreFromTransfer for the read case: returns the produced
// amx.tile_load result, or presumably failure when the transfer could not be
// converted — the failure-path lines (297-298) are missing from this extract.
292 static FailureOr<TypedValue<amx::TileType>>
293 loadFromTransfer(
PatternRewriter &rewriter, vector::TransferReadOp readOp,
295 amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
296 loadStoreFromTransfer(rewriter, readOp, isPacked));
299 return loadOp.getRes();
// Tail of storeFromTransfer (name inferred from the call site below; TODO
// confirm): succeeds iff loadStoreFromTransfer emitted an amx.tile_store for
// the write. isPacked is hard-coded false — stores are never VNNI-packed
// (consistent with the isPacked rejection in loadStoreFromTransfer above).
305 vector::TransferWriteOp writeOp,
307 return success(loadStoreFromTransfer(rewriter, writeOp,
false,
// Fragment of a helper that materializes a vector value as an AMX tile:
// if the vector comes straight from a vector.transfer_read, try converting
// that read into an amx.tile_load; otherwise spill the vector to a stack
// buffer (memref.alloca + transfer_write) and tile-load from the buffer.
// NOTE(review): partial extraction — original lines 318-319, 323-326,
// 329-330, 332-337, 341-342, and 345+ are missing; `shape`, `tileType`,
// `indices`, and `zeroIndex` declarations are not visible.
316 VectorType vecTy = vec.getType();
317 bool isPacked = vecTy.getRank() == 3;
// Fast path: fold a producing transfer_read directly into a tile load.
320 auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
321 FailureOr<TypedValue<amx::TileType>>
tile =
322 loadFromTransfer(rewriter, readOp, isPacked);
// Slow path: round-trip through a temporary stack buffer.
327 Value buf = memref::AllocaOp::create(
328 rewriter, loc,
MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
331 vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
// Flatten all trailing dims into the tile's column count.
338 int64_t
rows = shape[0];
339 int64_t
cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
340 std::multiplies<int64_t>());
343 return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
344 {zeroIndex, zeroIndex});
// Fragment of storeTile (name inferred from the call site below; TODO
// confirm): materializes an AMX tile back into a vector by storing it to a
// temporary stack buffer and reading the buffer with a vector.transfer_read.
// NOTE(review): partial extraction — original lines 355-358 and 360 are
// missing; `indices` declaration is not visible.
353 amx::TileType tileTy =
tile.getType();
354 Value buf = memref::AllocaOp::create(
359 amx::TileStoreOp::create(rewriter, loc, buf, indices,
tile);
361 auto vecTy =
VectorType::get(tileTy.getShape(), tileTy.getElementType());
362 return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
// Fragment of the contraction pattern's matchAndRewrite: converts a
// vector.contract with ADD combining kind into an AMX tile multiply
// (amx.tile_mulf for float accumulators, amx.tile_muli otherwise). If the
// contract's sole consumer is a transfer_write, the result tile is stored
// directly via storeFromTransfer; otherwise the tile is materialized back
// into a vector with storeTile and the op is replaced with it.
// NOTE(review): partial extraction — original lines 369-371, 373, 376-379,
// 382-384, 388, 391-394, 396, 400-406, and 409+ are missing; `lhsTile`,
// `rhsTile`, `accTile`, and `tileMul` declarations are not visible.
368 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
372 if (contractOp.getKind() != vector::CombiningKind::ADD)
374 "Expects add combining kind");
375 if (
failed(validateOperands(rewriter, contractOp)))
380 auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
381 assert(acc &&
"Invalid accumulator type");
// Pick the float or integer tile-multiply matching the accumulator type.
385 if (acc.getType().getElementType().isFloat()) {
386 tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
387 lhsTile, rhsTile, accTile);
389 tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
390 lhsTile, rhsTile, accTile);
// Try to fuse the result tile directly into a consuming transfer_write.
395 Value res = contractOp.getResult();
397 auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.
getUsers().begin());
398 LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
399 if (succeeded(storeRes)) {
// Fallback: materialize the tile as a vector and replace the contract.
407 Value newResult = storeTile(rewriter, tileMul);
408 rewriter.
replaceOp(contractOp, newResult);
// Fragment of the pass driver generated from ConvertVectorToAMXBase:
// runOnOperation presumably populates the vector-to-AMX patterns and applies
// them greedily (populateVectorToAMXConversionPatterns /
// applyPatternsGreedily per the reference list below in the extract),
// signaling pass failure if rewriting fails — the lines between 417 and 420
// are missing, so confirm against the full source.
414 struct ConvertVectorToAMXPass
415 :
public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
416 void runOnOperation()
override {
421 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Include the generated interface declarations.
void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to AMX ops.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...