MLIR 23.0.0git
VectorToGPU.cpp
Go to the documentation of this file.
1//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to GPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/Region.h"
30#include "mlir/Pass/Pass.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/DebugLog.h"
35
36#define DEBUG_TYPE "vector-to-gpu"
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTVECTORTOGPU
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
46/// AffineMap representing offsets to apply to indices, the function fills
47/// `indices` with the original indices plus the offsets. The offsets are
48/// applied by taking into account the permutation map of the transfer op. If
49/// the `offsetMap` has dimension placeholders, those should be provided in
50/// `dimValues`.
51template <typename TransferOpType>
52static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
53 AffineMap offsetMap, ArrayRef<Value> dimValues,
55 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
56 Location loc = xferOp.getLoc();
57 unsigned offsetsIdx = 0;
58 for (auto expr : xferOp.getPermutationMap().getResults()) {
59 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
60 Value prevIdx = indices[dim.getPosition()];
61 SmallVector<OpFoldResult, 3> dims(dimValues);
62 dims.push_back(prevIdx);
63 AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
64 indices[dim.getPosition()] = affine::makeComposedAffineApply(
65 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
66 continue;
67 }
68 }
69}
70
71// Return true if the contract op can be convert to MMA matmul.
72static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
73 bool useNvGpu) {
74 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
75 auto infer = [&](MapList m) {
76 return AffineMap::inferFromExprList(m, contract.getContext());
77 };
78 AffineExpr m, n, k;
79 bindDims(contract.getContext(), m, n, k);
80 auto iteratorTypes = contract.getIteratorTypes().getValue();
81 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
82 vector::isParallelIterator(iteratorTypes[1]) &&
83 vector::isReductionIterator(iteratorTypes[2])))
84 return false;
85
86 // The contract needs to represent a matmul to be able to convert to
87 // MMAMatrix matmul.
88 if (!useNvGpu &&
89 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
90 return false;
91 if (useNvGpu &&
92 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
93 return false;
94
95 return true;
96}
97
98// Test whether the permutation map's first result corresponds to its last
99// dimension.
100//
101// In contexts where we only accept maps that have the last (most minor)
102// dimension as exactly one of the two results, this is sufficient to classify
103// whether it represents a transpose.
104static bool isFirstResultLastMapDimension(AffineMap permutationMap) {
105 MLIRContext *ctx = permutationMap.getContext();
106 const unsigned nDim = permutationMap.getNumDims();
107 if (0 == nDim || permutationMap.getResults().empty())
108 return false;
109 return permutationMap.getResult(0) == getAffineDimExpr(nDim - 1, ctx);
110}
111
112// Return the `leadDimension` (row stride) implied by |permutationMap| for
113// |type|, if |type| is a memref with a statically-known layout.
114//
115// The `leadDimension` is the stride (in elements) between consecutive rows in
116// the 2D view described by |permutationMap|. This helper supports the subset
117// of maps permitted by vector.transfer_read:
118// - Exactly 2 results.
119// - Each result is either an affine dimension or the constant 0 (broadcast).
120//
121// Constraints:
122// - Requires the most minor memref stride to be 1.
123//
124// Broadcast:
125// - If either result is constant 0, the implied `leadDimension` is 0.
126static std::optional<int64_t>
127getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
128 auto memrefType = dyn_cast<MemRefType>(type);
129 if (!memrefType)
130 return std::nullopt;
131 // If the memref is 0 or 1D the horizontal stride is 0.
132 if (memrefType.getRank() < 2)
133 return 0;
134 int64_t offset = 0;
135 SmallVector<int64_t> strides;
136 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
137 strides.back() != 1)
138 return std::nullopt;
139
140 if (permutationMap.getNumResults() != 2)
141 return std::nullopt;
142
143 unsigned strideIndex = strides.size();
144
145 for (AffineExpr result : permutationMap.getResults()) {
146 if (auto cst = dyn_cast<AffineConstantExpr>(result)) {
147 // Constant value must be zero.
148 if (0 != cst.getValue())
149 return std::nullopt;
150 // A broadcast result forces row stride to 0.
151 return 0;
152 }
153 auto dim = dyn_cast<AffineDimExpr>(result);
154 // Only Dim & Const results are supported.
155 if (!dim)
156 return std::nullopt;
157 strideIndex = std::min(strideIndex, dim.getPosition());
158 }
159
160 // Structural validity check: ensure that the map selects at least one
161 // dimension more major than the most minor dimension. This also excludes
162 // degenerate cases where both results map to the most minor dimension.
163 if (strideIndex + 1 >= strides.size())
164 return std::nullopt;
165
166 const int64_t stride = strides[strideIndex];
167 if (stride == ShapedType::kDynamic)
168 return std::nullopt;
169 return stride;
170}
171
172// Return true if the transfer op can be converted to a MMA matrix load.
173static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
174 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
175 readOp.getVectorType().getRank() != 2)
176 return false;
177
178 AffineMap permutationMap = readOp.getPermutationMap();
179 if (!getStaticallyKnownRowStride(readOp.getShapedType(), permutationMap))
180 return false;
181
182 // Only allow integer types if the signedness can be inferred.
183 if (readOp.getVectorType().getElementType().isInteger(8))
184 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
185 !isa<arith::ExtUIOp>(*readOp->user_begin())))
186 return false;
187
188 MLIRContext *ctx = readOp.getContext();
189 AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
190 return llvm::is_contained(permutationMap.getResults(), innerDim);
191}
192
193// Return true if the transfer op can be converted to a MMA matrix store.
194static bool
195transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
196 // TODO: support 0-d corner case.
197 if (writeOp.getTransferRank() == 0)
198 return false;
199
200 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
201 writeOp.getVectorType().getRank() != 2)
202 return false;
203
204 AffineMap permutationMap = writeOp.getPermutationMap();
205 std::optional<int64_t> stride =
206 getStaticallyKnownRowStride(writeOp.getShapedType(), permutationMap);
207 // Stride of zero means broadcast which is not permitted for writes.
208 if (!stride.has_value() || stride.value() == 0)
209 return false;
210
211 MLIRContext *ctx = writeOp.getContext();
212 AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
213 // TODO: Support transpose once it is added to GPU dialect ops.
214 return permutationMap.getResult(1) == innerDim;
215}
216
217/// Return true if the constant is a splat to a 2D vector so that it can be
218/// converted to a MMA constant matrix op.
219static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
220 auto vecType = dyn_cast<VectorType>(constantOp.getType());
221 if (!vecType || vecType.getRank() != 2)
222 return false;
223 return isa<SplatElementsAttr>(constantOp.getValue());
224}
225
226/// Return true if this is a broadcast from scalar to a 2D vector.
227static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
228 return broadcastOp.getResultVectorType().getRank() == 2;
229}
230
231/// Return true if this integer extend op can be folded into a contract op.
232template <typename ExtOpTy>
233static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
234 auto transferReadOp =
235 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
236 if (!transferReadOp)
237 return false;
238 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
239}
240
241static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
242static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp extOp) { return true; }
243
244/// Return the MMA elementwise enum associated with `op` if it is supported.
245/// Return `std::nullopt` otherwise.
246static std::optional<gpu::MMAElementwiseOp>
248 using MMAEwO = gpu::MMAElementwiseOp;
250 .Case([](arith::AddFOp) { return MMAEwO::ADDF; })
251 .Case([](arith::AddIOp) { return MMAEwO::ADDI; })
252 .Case([](arith::DivFOp) { return MMAEwO::DIVF; })
253 .Case([](arith::DivSIOp) { return MMAEwO::DIVS; })
254 .Case([](arith::DivUIOp) { return MMAEwO::DIVU; })
255 .Case([](arith::ExtFOp) { return MMAEwO::EXTF; })
256 .Case([](arith::MaximumFOp) { return MMAEwO::MAXF; })
257 .Case([](arith::MinimumFOp) { return MMAEwO::MINF; })
258 .Case([](arith::MulFOp) { return MMAEwO::MULF; })
259 .Case([](arith::MulIOp) { return MMAEwO::MULI; })
260 .Case([](arith::NegFOp) { return MMAEwO::NEGATEF; })
261 .Case([](arith::SubFOp) { return MMAEwO::SUBF; })
262 .Case([](arith::SubIOp) { return MMAEwO::SUBI; })
263 .Case([](arith::TruncFOp) { return MMAEwO::TRUNCF; })
264 .Default(std::nullopt);
265}
266
267/// Return true if the op is supported as elementwise op on MMAMatrix type.
269 return convertElementwiseOpToMMA(op).has_value();
270}
271
272/// Returns true if the extract strided slice op is supported with `mma.sync`
273/// path.
274static bool
275extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
276
277 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
279 if (failed(warpMatrixInfo))
280 return false;
281
282 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
283 if (failed(contractOp))
284 return false;
285
286 // Handle vector.extract_strided_slice on registers containing
287 // matrixB and matrixC operands. vector.extract_strided_slice op
288 // is not supported on registers containing matrixA operands.
289 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
290 return (cast<VectorType>(op->getResult(0).getType()) ==
291 cast<VectorType>((*contractOp).getRhs().getType()));
292 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
293 return (cast<VectorType>(op->getResult(0).getType()) ==
294 cast<VectorType>((*contractOp).getAcc().getType()));
295
296 return false;
297}
298
299static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
300 if (isa<scf::ForOp, scf::YieldOp>(op))
301 return true;
302 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
303 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
304 : transferReadSupportsMMAMatrixType(transferRead);
305 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
306 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
307 : transferWriteSupportsMMAMatrixType(transferWrite);
308 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
309 return useNvGpu &&
310 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
311 if (auto contract = dyn_cast<vector::ContractionOp>(op))
312 return contractSupportsMMAMatrixType(contract, useNvGpu);
313 if (auto constant = dyn_cast<arith::ConstantOp>(op))
314 return constantSupportsMMAMatrixType(constant);
315 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
317 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
319 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
321 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
322 return fpExtendSupportsMMAMatrixType(fpExtend);
323 if (auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
324 return fpTruncSupportsMMAMatrixType(fpTrunc);
326}
327
328/// Return an unsorted slice handling scf.for region differently than
329/// `getSlice`. In scf.for we only want to include as part of the slice elements
330/// that are part of the use/def chain.
333 const BackwardSliceOptions &backwardSliceOptions,
334 const ForwardSliceOptions &forwardSliceOptions) {
336 slice.insert(op);
337 unsigned currentIndex = 0;
338 SetVector<Operation *> backwardSlice;
339 SetVector<Operation *> forwardSlice;
340 while (currentIndex != slice.size()) {
341 auto *currentOp = (slice)[currentIndex];
342 // Compute and insert the backwardSlice starting from currentOp.
343 backwardSlice.clear();
344 LogicalResult result =
345 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
346 assert(result.succeeded() && "expected a backward slice");
347 (void)result;
348 slice.insert_range(backwardSlice);
349
350 // Compute and insert the forwardSlice starting from currentOp.
351 forwardSlice.clear();
352 // Special case for ForOp, we don't want to include the whole region but
353 // only the value using the region arguments.
354 // TODO: We should refine this to only care about the region arguments being
355 // converted to matrix type.
356 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
357 for (Value forOpResult : forOp.getResults())
358 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
359 for (BlockArgument &arg : forOp.getRegionIterArgs())
360 getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
361 } else {
362 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
363 }
364 slice.insert_range(forwardSlice);
365 ++currentIndex;
366 }
367 return slice;
368}
369
// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
//
// NOTE(review): the opening signature line of this function was lost in
// extraction; presumably
// `static SetVector<Operation *> getOpToConvert(mlir::Operation *op,` --
// confirm against upstream before relying on this text.
                                          bool useNvGpu) {
  // Predicate: op produces at least one vector-typed result.
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  // Predicate: op consumes at least one vector-typed operand.
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](Operation *nestedOp) {
    // Only seed slices from specific root ops.
    // NOTE(review): the second operand of this `&&` was lost in extraction --
    // confirm the full condition against upstream.
    if (!isa<vector::ContractionOp>(nestedOp) &&
      return;
    // Skip roots already absorbed into a previously computed slice.
    if (opToConvert.contains(nestedOp))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use MMA matrix type drop the whole
    // chain. MMA matrix are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LDBG() << "cannot convert op: " << *op;
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert_range(dependentOps);
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}
412
413namespace {
414// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
415// to MMA matmul.
416struct PrepareContractToGPUMMA
417 : public OpRewritePattern<vector::ContractionOp> {
418 using Base::Base;
419
420 LogicalResult matchAndRewrite(vector::ContractionOp op,
421 PatternRewriter &rewriter) const override {
422 Location loc = op.getLoc();
423 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
424
425 // Set up the parallel/reduction structure in right form.
426 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
427 auto infer = [&](MapList m) {
428 return AffineMap::inferFromExprList(m, op.getContext());
429 };
430 AffineExpr m, n, k;
431 bindDims(rewriter.getContext(), m, n, k);
432 static constexpr std::array<int64_t, 2> perm = {1, 0};
433 auto iteratorTypes = op.getIteratorTypes().getValue();
434 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
435 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
436 vector::isParallelIterator(iteratorTypes[1]) &&
437 vector::isReductionIterator(iteratorTypes[2])))
438 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
439 //
440 // Two outer parallel, one inner reduction (matmat flavor).
441 //
442 // This is the classical row-major matmul, nothing to do.
443 if (maps == infer({{m, k}, {k, n}, {m, n}}))
444 return rewriter.notifyMatchFailure(op, "contraction already prepared");
445 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
446 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
447 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
448 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
449 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
450 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
451 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
452 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
453 std::swap(rhs, lhs);
454 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
455 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
456 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
457 std::swap(rhs, lhs);
458 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
459 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
460 std::swap(lhs, rhs);
461 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
462 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
463 std::swap(lhs, rhs);
464 } else {
465 // TODO: llvm_unreachable ?
466 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
467 }
468 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
469 op, lhs, rhs, res,
470 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
471 op.getIteratorTypes());
472 return success();
473 }
474};
475
476// Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
477// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
478// respectively. We can fold the transpose operation when loading the data from
479// Shared Memory to registers.
480struct CombineTransferReadOpTranspose final
481 : public OpRewritePattern<vector::TransposeOp> {
482 using Base::Base;
483
484 LogicalResult matchAndRewrite(vector::TransposeOp op,
485 PatternRewriter &rewriter) const override {
486 // Look through integer extend ops.
487 Value source = op.getVector();
488 Type resultType = op.getType();
489 Operation *extOp;
490 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
491 (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
492 (extOp = source.getDefiningOp<arith::ExtFOp>())) {
493 source = extOp->getOperand(0);
494 resultType =
495 VectorType::get(cast<VectorType>(resultType).getShape(),
496 cast<VectorType>(source.getType()).getElementType());
497 }
498
499 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
500 if (!transferReadOp)
501 return rewriter.notifyMatchFailure(op, "no transfer read");
502
503 // TODO: support 0-d corner case.
504 if (transferReadOp.getTransferRank() == 0)
505 return rewriter.notifyMatchFailure(op, "0-D transfer read");
506
507 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
508 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
509
510 AffineMap permutationMap =
511 AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
512 AffineMap newMap =
513 permutationMap.compose(transferReadOp.getPermutationMap());
514
515 auto loc = op.getLoc();
516 Value result = vector::TransferReadOp::create(
517 rewriter, loc, resultType, transferReadOp.getBase(),
518 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
519 transferReadOp.getPadding(), transferReadOp.getMask(),
520 transferReadOp.getInBoundsAttr())
521 .getResult();
522
523 // Fuse through the integer extend op.
524 if (extOp) {
525 if (isa<arith::ExtSIOp>(extOp))
526 result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
527 .getResult();
528 else if (isa<arith::ExtUIOp>(extOp))
529 result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
530 .getResult();
531 else
532 result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
533 .getResult();
534 }
535
536 rewriter.replaceOp(op, result);
537 return success();
538 }
539};
540
541} // namespace
542
543// MMA types have different layout based on how they are used in matmul ops.
544// Figure the right layout to use by looking at op uses.
545// TODO: Change the GPU dialect to abstract the layout at the this level and
546// only care about it during lowering to NVVM.
547static const char *inferFragType(Operation *op) {
548 // We can have arith.ext ops before reaching contract ops. See through them
549 // and other kinds of elementwise ops.
550 if (op->hasOneUse()) {
551 Operation *userOp = *op->user_begin();
552 if (userOp->hasTrait<OpTrait::Elementwise>())
553 return inferFragType(userOp);
554 }
555
556 for (Operation *users : op->getUsers()) {
557 auto contract = dyn_cast<vector::ContractionOp>(users);
558 if (!contract)
559 continue;
560 assert(op->getNumResults() == 1);
561 if (contract.getLhs() == op->getResult(0))
562 return "AOp";
563 if (contract.getRhs() == op->getResult(0))
564 return "BOp";
565 }
566 return "COp";
567}
568
569static LogicalResult
570convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
571 llvm::DenseMap<Value, Value> &valueMapping) {
572 OpBuilder::InsertionGuard g(rewriter);
573 rewriter.setInsertionPoint(op);
574
575 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
577 "expected convertible operation");
578
579 AffineMap permutationMap = op.getPermutationMap();
580 std::optional<int64_t> stride =
581 getStaticallyKnownRowStride(op.getShapedType(), permutationMap);
582 if (!stride.has_value()) {
583 LDBG() << "no stride";
584 return rewriter.notifyMatchFailure(op, "no stride");
585 }
586
587 // transferReadSupportsMMAMatrixType ensures that either of the map results is
588 // the most minor dimension. Under this constraint, whether the map represents
589 // a transposed view can be inferred from whether the first result is the most
590 // minor memref dimension.
591 const bool isTranspose = isFirstResultLastMapDimension(permutationMap);
592
593 Value mappingResult = op.getResult();
594 auto elType = op.getVectorType().getElementType();
595 const char *fragType = inferFragType(op);
596 if (op->hasOneUse()) {
597 auto *user = *op->user_begin();
598 // Infer the signedness of the mma type from the integer extend.
599 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
600 elType = IntegerType::get(
601 op.getContext(), cast<IntegerType>(elType).getWidth(),
602 isa<arith::ExtSIOp>(user) ? IntegerType::Signed
603 : IntegerType::Unsigned);
604 mappingResult = user->getResult(0);
605 }
606 }
607 gpu::MMAMatrixType type =
608 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
609 Value load = gpu::SubgroupMmaLoadMatrixOp::create(
610 rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
611 rewriter.getIndexAttr(*stride),
612 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
613 valueMapping[mappingResult] = load;
614
615 LDBG() << "transfer read to: " << load;
616 return success();
617}
618
619static LogicalResult
620convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
621 llvm::DenseMap<Value, Value> &valueMapping) {
622 OpBuilder::InsertionGuard g(rewriter);
623 rewriter.setInsertionPoint(op);
624
626 std::optional<int64_t> stride =
627 getStaticallyKnownRowStride(op.getShapedType(), op.getPermutationMap());
628 if (!stride.has_value()) {
629 LDBG() << "no stride";
630 return rewriter.notifyMatchFailure(op, "no stride");
631 }
632
633 auto it = valueMapping.find(op.getVector());
634 if (it == valueMapping.end()) {
635 LDBG() << "no mapping";
636 return rewriter.notifyMatchFailure(op, "no mapping");
637 }
638
639 Value matrix = it->second;
640 auto store = gpu::SubgroupMmaStoreMatrixOp::create(
641 rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
642 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
643 (void)store;
644
645 LDBG() << "transfer write to: " << store;
646
647 LDBG() << "erase: " << op;
648 rewriter.eraseOp(op);
649 return success();
650}
651
652/// Returns the vector type which represents a matrix fragment.
653static VectorType
654getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
655 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
656 regInfo.elementsPerRegister};
657 Type elType = regInfo.registerLLVMType;
658 if (auto vecType = dyn_cast<VectorType>(elType))
659 elType = vecType.getElementType();
660 return VectorType::get(shape, elType);
661}
662
663/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
664static LogicalResult
665convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
666 llvm::DenseMap<Value, Value> &valueMapping) {
667 OpBuilder::InsertionGuard g(rewriter);
668 rewriter.setInsertionPoint(op);
669
670 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
672 if (failed(warpMatrixInfo)) {
673 LDBG() << "no warpMatrixInfo";
674 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
675 }
676
677 FailureOr<nvgpu::FragmentElementInfo> regInfo =
678 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
679 if (failed(regInfo)) {
680 LDBG() << "not mma sync reg info";
681 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
682 }
683
684 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
685 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
686 if (!dense) {
687 LDBG() << "not a splat";
688 return rewriter.notifyMatchFailure(op, "not a splat");
689 }
690
691 Value result = arith::ConstantOp::create(
692 rewriter, op.getLoc(), vectorType,
693 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
694 valueMapping[op.getResult()] = result;
695 return success();
696}
697
698/// Check if the loaded matrix operand requires transposed.
699/// Transposed Map Example:
700/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
701/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
702/// The code below checks if the output 2D is transposed using a generalized
703/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
704/// Returns : true; if m > n, false o.w.
705static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
707
708 if (map.getNumResults() != 2) {
709 LDBG() << "Failed because the result of `vector.transfer_read` "
710 "is not a 2d operand";
711 return failure();
712 }
713
714 // Output 2D matrix dimensions in the order of d0, d1.
715 mlir::AffineExpr dM = map.getResult(0);
716 mlir::AffineExpr dN = map.getResult(1);
717
718 // Find the position of these expressions in the input.
719 auto exprM = dyn_cast<AffineDimExpr>(dM);
720 auto exprN = dyn_cast<AffineDimExpr>(dN);
721
722 if (!exprM || !exprN) {
723 LDBG() << "Failed because expressions are not affine dim "
724 "expressions, then transpose cannot be determined.";
725 return failure();
726 }
727
728 return exprM.getPosition() > exprN.getPosition();
729}
730
731static LogicalResult
732creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
733 llvm::DenseMap<Value, Value> &valueMapping) {
734 OpBuilder::InsertionGuard g(rewriter);
735 rewriter.setInsertionPoint(op);
736 Location loc = op->getLoc();
737
738 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
740 if (failed(warpMatrixInfo)) {
741 LDBG() << "no warpMatrixInfo";
742 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
743 }
744
745 FailureOr<nvgpu::FragmentElementInfo> regInfo =
746 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
747 if (failed(regInfo)) {
748 LDBG() << "not mma sync reg info";
749 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
750 }
751
752 FailureOr<bool> transpose = isTransposed(op);
753 if (failed(transpose)) {
754 LDBG() << "failed to determine the transpose";
755 return rewriter.notifyMatchFailure(
756 op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
757 }
758
759 FailureOr<nvgpu::LdMatrixParams> params =
760 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
761
762 if (failed(params)) {
763 LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
764 << "Op should likely not be converted to a nvgpu.ldmatrix call.";
765 return rewriter.notifyMatchFailure(
766 op, "failed to convert vector.transfer_read to ldmatrix; this op "
767 "likely should not be converted to a nvgpu.ldmatrix call.");
768 }
769
770 // Adjust the load offset.
771 auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
772 FailureOr<AffineMap> offsets =
773 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
774 if (failed(offsets)) {
775 LDBG() << "no offsets";
776 return rewriter.notifyMatchFailure(op, "no offsets");
777 }
778
779 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
780
782 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
783 indices);
784
785 nvgpu::LdMatrixOp newOp =
786 nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
787 indices, *transpose, params->numTiles);
788 valueMapping[op] = newOp->getResult(0);
789 return success();
790}
791
792static LogicalResult
793createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
794 llvm::DenseMap<Value, Value> &valueMapping) {
795 OpBuilder::InsertionGuard g(rewriter);
796 rewriter.setInsertionPoint(op);
797
798 Location loc = op.getLoc();
799 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
801 if (failed(warpMatrixInfo))
802 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
803 FailureOr<nvgpu::FragmentElementInfo> regInfo =
804 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
805 if (failed(regInfo)) {
806 return rewriter.notifyMatchFailure(
807 op, "Failed to deduce register fragment type during "
808 "conversion to distributed non-ldmatrix compatible load");
809 }
810
811 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
812
813 // This is the individual element type.
814 Type loadedElType = regInfo->registerLLVMType;
815 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
816
817 Value fill = arith::ConstantOp::create(
818 rewriter, op.getLoc(), vectorType.getElementType(),
819 rewriter.getZeroAttr(vectorType.getElementType()));
820 Value result =
821 vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
822
823 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
824
825 // If we are not transposing, then we can use vectorized loads. Otherwise, we
826 // must load each element individually.
827 if (!isTransposeLoad) {
828 if (!isa<VectorType>(loadedElType)) {
829 loadedElType = VectorType::get({1}, loadedElType);
830 }
831
832 for (int i = 0; i < vectorType.getShape()[0]; i++) {
833 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
834 rewriter, op.getLoc(), *warpMatrixInfo);
835 if (failed(coords))
836 return rewriter.notifyMatchFailure(op, "no coords");
837
838 Value logicalValueId = arith::ConstantOp::create(
839 rewriter, loc, rewriter.getIndexType(),
840 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
841 SmallVector<Value, 4> newIndices;
843 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
844
845 Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
846 op.getBase(), newIndices);
847 result = vector::InsertOp::create(rewriter, loc, el, result, i);
848 }
849 } else {
850 if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
851 loadedElType = vecType.getElementType();
852 }
853 for (int i = 0; i < vectorType.getShape()[0]; i++) {
854 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
855 innerIdx++) {
856
857 Value logicalValueId = arith::ConstantOp::create(
858 rewriter, loc, rewriter.getIndexType(),
859 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
860 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
861 rewriter, op.getLoc(), *warpMatrixInfo);
862 if (failed(coords))
863 return rewriter.notifyMatchFailure(op, "no coords");
864
865 SmallVector<Value, 4> newIndices;
867 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
868 Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
869 op.getBase(), newIndices);
870 result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
871 ArrayRef<int64_t>{i, innerIdx});
872 }
873 }
874 }
875
876 valueMapping[op.getResult()] = result;
877 return success();
878}
879
880/// Return true if this is a shared memory memref type.
881static bool isSharedMemory(MemRefType type) {
882 auto addressSpace =
883 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
884 return addressSpace &&
885 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
886}
887
888/// Converts a `vector.transfer_read` operation directly to either a
889/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
890/// used when converting to `nvgpu.mma.sync` operations.
891static LogicalResult
892convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
893 llvm::DenseMap<Value, Value> &valueMapping) {
894 OpBuilder::InsertionGuard g(rewriter);
895 rewriter.setInsertionPoint(op);
896
897 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
899 if (failed(warpMatrixInfo))
900 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
901
902 bool isLdMatrixCompatible =
903 isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
904 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
905
906 VectorType vecTy = op.getVectorType();
907 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
908
909 // When we are transposing the B operand, ldmatrix will only work if we have
910 // at least 8 rows to read and the width to read for the transpose is 128
911 // bits.
912 if (!op.getPermutationMap().isMinorIdentity() &&
913 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
914 vecTy.getDimSize(0) * bitWidth < 128))
915 isLdMatrixCompatible = false;
916
917 if (!isLdMatrixCompatible)
918 return createNonLdMatrixLoads(rewriter, op, valueMapping);
919
920 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
921}
922
923static LogicalResult
924convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
925 llvm::DenseMap<Value, Value> &valueMapping) {
926 OpBuilder::InsertionGuard g(rewriter);
927 rewriter.setInsertionPoint(op);
928
929 Location loc = op->getLoc();
930 auto it = valueMapping.find(op.getVector());
931 if (it == valueMapping.end())
932 return rewriter.notifyMatchFailure(op, "no mapping");
933 Value matrix = it->second;
934
935 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
937 if (failed(warpMatrixInfo))
938 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
939 FailureOr<nvgpu::FragmentElementInfo> regInfo =
940 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
941 if (failed(regInfo))
942 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
943
944 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
945 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
946
947 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
948 Value logicalValueId = arith::ConstantOp::create(
949 rewriter, loc, rewriter.getIndexType(),
950 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
951 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
952 rewriter, op.getLoc(), *warpMatrixInfo);
953 if (failed(coords))
954 return rewriter.notifyMatchFailure(op, "no coords");
955
956 Value el =
957 vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
958 SmallVector<Value, 4> newIndices;
960 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
961 vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
962 }
963
964 LDBG() << "erase: " << op;
965 rewriter.eraseOp(op);
966 return success();
967}
968
970 SmallVectorImpl<int64_t> &results) {
971 for (auto attr : arrayAttr)
972 results.push_back(cast<IntegerAttr>(attr).getInt());
973}
974
975static LogicalResult
977 vector::ExtractStridedSliceOp op,
978 llvm::DenseMap<Value, Value> &valueMapping) {
979 OpBuilder::InsertionGuard g(rewriter);
980 rewriter.setInsertionPoint(op);
981
982 Location loc = op->getLoc();
983
984 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
986 if (failed(warpMatrixInfo))
987 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
988
989 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
990 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
991 if (failed(mmaSyncFragmentInfo))
992 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
993
994 // Find the vector.transer_read whose result vector is being sliced.
995 auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
996 if (!transferReadOp)
997 return rewriter.notifyMatchFailure(op, "no transfer read");
998
999 warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
1000 if (failed(warpMatrixInfo))
1001 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
1002
1003 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
1004 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
1005 if (failed(ldFragmentInfo))
1006 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
1007
1008 assert(
1009 (mmaSyncFragmentInfo->elementsPerRegister ==
1010 ldFragmentInfo->elementsPerRegister) &&
1011 "Number of elements per register should be same for load and mma.sync");
1012
1013 // Create vector.extract_strided_slice op for thread-owned fragments.
1014 std::array<int64_t, 2> strides = {1,
1015 1}; // stride for extract slice is always 1.
1016 std::array<int64_t, 2> sliceShape = {
1017 mmaSyncFragmentInfo->numRegistersPerFragment,
1018 mmaSyncFragmentInfo->elementsPerRegister};
1019 auto it = valueMapping.find(transferReadOp);
1020 if (it == valueMapping.end())
1021 return rewriter.notifyMatchFailure(op, "no mapping");
1022 auto sourceVector = it->second;
1023
1024 // offset and sizes at warp-level of onwership.
1025 SmallVector<int64_t> offsets;
1026 populateFromInt64AttrArray(op.getOffsets(), offsets);
1027
1029 populateFromInt64AttrArray(op.getSizes(), sizes);
1030 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1031
1032 // Compute offset in vector registers. Note that the mma.sync vector registers
1033 // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
1034 // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1035 std::array<int64_t, 2> sliceOffset = {0, 0};
1036
1037 if (offsets[0] && offsets[1])
1038 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1039 if (offsets[0])
1040 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1041 else if (offsets[1])
1042 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1043
1044 Value newOp = vector::ExtractStridedSliceOp::create(
1045 rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1046
1047 valueMapping[op] = newOp;
1048 return success();
1049}
1050
1051static LogicalResult
1052convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1053 llvm::DenseMap<Value, Value> &valueMapping) {
1054 OpBuilder::InsertionGuard g(rewriter);
1055 rewriter.setInsertionPoint(op);
1056
1057 auto itA = valueMapping.find(op.getLhs());
1058 auto itB = valueMapping.find(op.getRhs());
1059 auto itC = valueMapping.find(op.getAcc());
1060 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1061 itC == valueMapping.end())
1062 return rewriter.notifyMatchFailure(op, "no mapping");
1063 Value opA = itA->second, opB = itB->second, opC = itC->second;
1064 Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1065 opC.getType(), opA, opB, opC,
1066 /*a_transpose=*/UnitAttr(),
1067 /*b_transpose=*/UnitAttr());
1068 valueMapping[op.getResult()] = matmul;
1069 return success();
1070}
1071
1072static LogicalResult
1073convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1074 llvm::DenseMap<Value, Value> &valueMapping) {
1075 OpBuilder::InsertionGuard g(rewriter);
1076 rewriter.setInsertionPoint(op);
1077
1078 auto itA = valueMapping.find(op.getLhs());
1079 auto itB = valueMapping.find(op.getRhs());
1080 auto itC = valueMapping.find(op.getAcc());
1081 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1082 itC == valueMapping.end())
1083 return rewriter.notifyMatchFailure(op, "no mapping");
1084 Value opA = itA->second, opB = itB->second, opC = itC->second;
1085 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1086 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1087 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1088 Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1089 rewriter.getI64ArrayAttr({m, n, k}));
1090 valueMapping[op.getResult()] = matmul;
1091 return success();
1092}
1093
1094/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1095static LogicalResult
1096convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1097 llvm::DenseMap<Value, Value> &valueMapping) {
1098 OpBuilder::InsertionGuard g(rewriter);
1099 rewriter.setInsertionPoint(op);
1100
1102
1103 auto splat =
1104 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1105 auto scalarConstant =
1106 arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1107 const char *fragType = inferFragType(op);
1108 auto vecType = cast<VectorType>(op.getType());
1110 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1111 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1112 type, scalarConstant);
1113 valueMapping[op.getResult()] = matrix;
1114 return success();
1115}
1116
1117/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1118static LogicalResult
1119convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1120 llvm::DenseMap<Value, Value> &valueMapping) {
1121 OpBuilder::InsertionGuard g(rewriter);
1122 rewriter.setInsertionPoint(op);
1123
1125
1126 const char *fragType = inferFragType(op);
1127 auto vecType = op.getResultVectorType();
1129 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1130 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1131 type, op.getSource());
1132 valueMapping[op.getResult()] = matrix;
1133 return success();
1134}
1135
1136// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1137// updated and needs to be updated separately for the loop to be correct.
1138static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1139 scf::ForOp loop,
1140 ValueRange newInitArgs) {
1141 OpBuilder::InsertionGuard g(rewriter);
1142 rewriter.setInsertionPoint(loop);
1143
1144 // Create a new loop before the existing one, with the extra operands.
1145 rewriter.setInsertionPoint(loop);
1146 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1147 llvm::append_range(operands, newInitArgs);
1148 scf::ForOp newLoop =
1149 scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1150 loop.getUpperBound(), loop.getStep(), operands);
1151 rewriter.eraseBlock(newLoop.getBody());
1152
1153 newLoop.getRegion().getBlocks().splice(
1154 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1155 for (Value operand : newInitArgs)
1156 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1157
1158 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1159 loop.getNumResults())))
1160 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1161
1162 LDBG() << "newLoop now: " << newLoop;
1163 LDBG() << "stripped scf.for: " << loop;
1164 LDBG() << "erase: " << loop;
1165
1166 rewriter.eraseOp(loop);
1167 return newLoop;
1168}
1169
1170static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1171 llvm::DenseMap<Value, Value> &valueMapping) {
1172 OpBuilder::InsertionGuard g(rewriter);
1173 rewriter.setInsertionPoint(op);
1174
1175 SmallVector<Value> newOperands;
1177 for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1178 auto it = valueMapping.find(operand.value());
1179 if (it == valueMapping.end()) {
1180 LDBG() << "no value mapping for: " << operand.value();
1181 continue;
1182 }
1183 argMapping.push_back(std::make_pair(
1184 operand.index(), op.getInitArgs().size() + newOperands.size()));
1185 newOperands.push_back(it->second);
1186 }
1187
1188 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1189 Block &loopBody = *newForOp.getBody();
1190 for (auto mapping : argMapping) {
1191 valueMapping[newForOp.getResult(mapping.first)] =
1192 newForOp.getResult(mapping.second);
1193 valueMapping[loopBody.getArgument(mapping.first +
1194 newForOp.getNumInductionVars())] =
1195 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1196 }
1197
1198 LDBG() << "scf.for to: " << newForOp;
1199 return success();
1200}
1201
1202static LogicalResult
1203convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1204 llvm::DenseMap<Value, Value> &valueMapping) {
1205 OpBuilder::InsertionGuard g(rewriter);
1206 rewriter.setInsertionPoint(op);
1207
1208 auto loop = cast<scf::ForOp>(op->getParentOp());
1209 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1210 for (const auto &operand : llvm::enumerate(op.getOperands())) {
1211 auto it = valueMapping.find(operand.value());
1212 if (it == valueMapping.end())
1213 continue;
1214 // Replace the yield of old value with the for op argument to make it easier
1215 // to remove the dead code.
1216 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1217 yieldOperands.push_back(it->second);
1218 }
1219 scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1220
1221 LDBG() << "erase: " << op;
1222 rewriter.eraseOp(op);
1223 return success();
1224}
1225
1226/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1227static LogicalResult
1229 gpu::MMAElementwiseOp opType,
1230 llvm::DenseMap<Value, Value> &valueMapping) {
1231 OpBuilder::InsertionGuard g(rewriter);
1232 rewriter.setInsertionPoint(op);
1233
1234 SmallVector<Value> matrixOperands;
1235 for (Value operand : op->getOperands()) {
1236 auto it = valueMapping.find(operand);
1237 if (it == valueMapping.end())
1238 return rewriter.notifyMatchFailure(op, "no mapping");
1239 matrixOperands.push_back(it->second);
1240 }
1241 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1242 if (opType == gpu::MMAElementwiseOp::EXTF ||
1243 opType == gpu::MMAElementwiseOp::TRUNCF) {
1244 // The floating point extension and truncation has a different result type.
1245 auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1246 resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1247 vectorType.getElementType(),
1248 resultType.getOperand());
1249 }
1250
1251 Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1252 rewriter, op->getLoc(), resultType, matrixOperands, opType);
1253 valueMapping[op->getResult(0)] = newOp;
1254 return success();
1255}
1256
1258 bool useNvGpu) {
1259 if (!useNvGpu) {
1260 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1261 patterns.getContext());
1262 return;
1263 }
1265 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1266}
1267
1269 Operation *rootOp) {
1270 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1271 llvm::DenseMap<Value, Value> valueMapping;
1272
1273 auto globalRes = LogicalResult::success();
1274 for (Operation *op : ops) {
1275 LDBG() << "Process op: " << *op;
1276 // Apparently callers do not want to early exit on failure here.
1277 auto res = LogicalResult::success();
1278 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1279 res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1280 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1281 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1282 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1283 res = convertContractOp(rewriter, contractOp, valueMapping);
1284 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1285 res = convertConstantOp(rewriter, constantOp, valueMapping);
1286 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1287 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1288 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1289 res = convertForOp(rewriter, forOp, valueMapping);
1290 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1291 res = convertYieldOp(rewriter, yieldOp, valueMapping);
1292 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1293 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1294 }
1295 if (failed(res))
1296 globalRes = failure();
1297 }
1298 return globalRes;
1299}
1300
1302 Operation *rootOp) {
1303 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1304 llvm::DenseMap<Value, Value> valueMapping;
1305 for (Operation *op : ops) {
1307 .Case([&](vector::TransferReadOp transferReadOp) {
1308 return convertTransferReadToLoads(rewriter, transferReadOp,
1309 valueMapping);
1310 })
1311 .Case([&](vector::TransferWriteOp transferWriteOp) {
1312 return convertTransferWriteToStores(rewriter, transferWriteOp,
1313 valueMapping);
1314 })
1315 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1316 return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1317 valueMapping);
1318 })
1319 .Case([&](vector::ContractionOp contractionOp) {
1320 return convertContractOpToMmaSync(rewriter, contractionOp,
1321 valueMapping);
1322 })
1323 .Case([&](scf::ForOp forOp) {
1324 return convertForOp(rewriter, forOp, valueMapping);
1325 })
1326 .Case([&](scf::YieldOp yieldOp) {
1327 return convertYieldOp(rewriter, yieldOp, valueMapping);
1328 })
1329 .Case([&](arith::ConstantOp constOp) {
1330 return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1331 })
1332 .Default([&](Operation *op) {
1333 return op->emitError() << "unhandled vector to mma type: " << *op;
1334 })
1335 .failed()) {
1336 return op->emitOpError()
1337 << "failed to convert op during vector-to-nvgpu conversion";
1338 }
1339 }
1340 return success();
1341}
1342
1343namespace {
1344
1345struct ConvertVectorToGPUPass
1346 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1347
1348 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1349 useNvGpu.setValue(useNvGpu_);
1350 }
1351
1352 void runOnOperation() override {
1353 RewritePatternSet patterns(&getContext());
1354 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1355 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1356 return signalPassFailure();
1357
1358 IRRewriter rewriter(&getContext());
1359 if (useNvGpu) {
1360 if (failed(
1361 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1362 return signalPassFailure();
1363 return;
1364 }
1365 (void)convertVectorToMMAOps(rewriter, getOperation());
1366 }
1367};
1368
1369} // namespace
1370
1371std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1372 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1373}
return success()
lhs
ArrayAttr()
b getContext())
auto load
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp extOp)
static const char * inferFragType(Operation *op)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool isFirstResultLastMapDimension(AffineMap permutationMap)
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:878
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:426
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
user_iterator user_begin()
Definition Operation.h:898
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition MMAUtils.cpp:48
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition MMAUtils.cpp:56
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
Returns the number of bits in a single tile row.
Definition MMAUtils.cpp:296
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SliceOptions ForwardSliceOptions
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter