VectorToGPU.cpp
1//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to GPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/Region.h"
30#include "mlir/Pass/Pass.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/DebugLog.h"
35
36#define DEBUG_TYPE "vector-to-gpu"
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTVECTORTOGPU
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
46/// AffineMap representing offsets to apply to indices, the function fills
47/// `indices` with the original indices plus the offsets. The offsets are
48/// applied by taking into account the permutation map of the transfer op. If
49/// the `offsetMap` has dimension placeholders, those should be provided in
50/// `dimValues`.
51template <typename TransferOpType>
52static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
53 AffineMap offsetMap, ArrayRef<Value> dimValues,
54 SmallVector<Value, 4> &indices) {
55 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
56 Location loc = xferOp.getLoc();
57 unsigned offsetsIdx = 0;
58 for (auto expr : xferOp.getPermutationMap().getResults()) {
59 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
60 Value prevIdx = indices[dim.getPosition()];
61 SmallVector<OpFoldResult, 3> dims(dimValues);
62 dims.push_back(prevIdx);
63 AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
64 indices[dim.getPosition()] = affine::makeComposedAffineApply(
65 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
66 continue;
67 }
68 }
69}
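// Illustrative sketch of the index rewrite (names are hypothetical): for a
// transfer op with indices [%i, %j], permutation map (d0, d1) -> (d0, d1),
// offsetMap (lane) -> (lane floordiv 4, lane mod 4) and dimValues = {%laneId},
// the filled `indices` become roughly
//   indices[0] = affine.apply affine_map<(lane, i) -> (i + lane floordiv 4)>(%laneId, %i)
//   indices[1] = affine.apply affine_map<(lane, j) -> (j + lane mod 4)>(%laneId, %j)
// i.e. each original index is offset by the offsetMap result that feeds the
// same permutation-map dimension.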
70
71// Return true if the contract op can be converted to MMA matmul.
72static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
73 bool useNvGpu) {
74 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
75 auto infer = [&](MapList m) {
76 return AffineMap::inferFromExprList(m, contract.getContext());
77 };
78 AffineExpr m, n, k;
79 bindDims(contract.getContext(), m, n, k);
80 auto iteratorTypes = contract.getIteratorTypes().getValue();
81 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
82 vector::isParallelIterator(iteratorTypes[1]) &&
83 vector::isReductionIterator(iteratorTypes[2])))
84 return false;
85
86 // The contract needs to represent a matmul to be able to convert to
87 // MMAMatrix matmul.
88 if (!useNvGpu &&
89 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
90 return false;
91 if (useNvGpu &&
92 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
93 return false;
94
95 return true;
96}
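// For example, with useNvGpu == false the following contraction is in the
// accepted (m, k) x (k, n) -> (m, n) row-major matmul form (operand names and
// shapes are illustrative):
//
//   %d = vector.contract {
//            indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                             affine_map<(m, n, k) -> (k, n)>,
//                             affine_map<(m, n, k) -> (m, n)>],
//            iterator_types = ["parallel", "parallel", "reduction"],
//            kind = #vector.kind<add>}
//          %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
//
// With useNvGpu == true the B operand must instead be indexed as (n, k),
// i.e. the (m, k) x (n, k) -> (m, n) matmul-transpose-B form.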
97
98// Test whether the permutation map's first result corresponds to its last
99// dimension.
100//
101// In contexts where we only accept maps that have the last (most minor)
102// dimension as exactly one of the two results, this is sufficient to classify
103// whether it represents a transpose.
104static bool isFirstResultLastMapDimension(AffineMap permutationMap) {
105 MLIRContext *ctx = permutationMap.getContext();
106 const unsigned nDim = permutationMap.getNumDims();
107 if (0 == nDim || permutationMap.getResults().empty())
108 return false;
109 return permutationMap.getResult(0) == getAffineDimExpr(nDim - 1, ctx);
110}
111
112// Return the `leadDimension` (row stride) implied by |permutationMap| for
113// |type|, if |type| is a memref with a statically-known layout.
114//
115// The `leadDimension` is the stride (in elements) between consecutive rows in
116// the 2D view described by |permutationMap|. This helper supports the subset
117// of maps permitted by vector.transfer_read:
118// - Exactly 2 results.
119// - Each result is either an affine dimension or the constant 0 (broadcast).
120//
121// Constraints:
122// - Requires the most minor memref stride to be 1.
123//
124// Broadcast:
125// - If either result is constant 0, the implied `leadDimension` is 0.
126static std::optional<int64_t>
127getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
128 auto memrefType = dyn_cast<MemRefType>(type);
129 if (!memrefType)
130 return std::nullopt;
131 // If the memref is 0 or 1D the horizontal stride is 0.
132 if (memrefType.getRank() < 2)
133 return 0;
134 int64_t offset = 0;
135 SmallVector<int64_t> strides;
136 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
137 strides.back() != 1)
138 return std::nullopt;
139
140 if (permutationMap.getNumResults() != 2)
141 return std::nullopt;
142
143 unsigned strideIndex = strides.size();
144
145 for (AffineExpr result : permutationMap.getResults()) {
146 if (auto cst = dyn_cast<AffineConstantExpr>(result)) {
147 // Constant value must be zero.
148 if (0 != cst.getValue())
149 return std::nullopt;
150 // A broadcast result forces row stride to 0.
151 return 0;
152 }
153 auto dim = dyn_cast<AffineDimExpr>(result);
154 // Only Dim & Const results are supported.
155 if (!dim)
156 return std::nullopt;
157 strideIndex = std::min(strideIndex, dim.getPosition());
158 }
159
160 // Structural validity check: ensure that the map selects at least one
161 // dimension more major than the most minor dimension. This also excludes
162 // degenerate cases where both results map to the most minor dimension.
163 if (strideIndex + 1 >= strides.size())
164 return std::nullopt;
165
166 const int64_t stride = strides[strideIndex];
167 if (stride == ShapedType::kDynamic)
168 return std::nullopt;
169 return stride;
170}
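// Worked example (illustrative): for memref<32x64xf16> with the default
// identity layout the strides are [64, 1]. With permutation map
// (d0, d1) -> (d0, d1) the most major accessed dimension is d0, so the row
// stride returned is strides[0] == 64. If one of the two map results were the
// constant 0 (a broadcast), the function would return 0 instead.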
171
172// Return true if the transfer op can be converted to a MMA matrix load.
173static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
174 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
175 readOp.getVectorType().getRank() != 2)
176 return false;
177
178 AffineMap permutationMap = readOp.getPermutationMap();
179 if (!getStaticallyKnownRowStride(readOp.getShapedType(), permutationMap))
180 return false;
181
182 // Only allow integer types if the signedness can be inferred.
183 if (readOp.getVectorType().getElementType().isInteger(8))
184 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
185 !isa<arith::ExtUIOp>(*readOp->user_begin())))
186 return false;
187
188 MLIRContext *ctx = readOp.getContext();
189 AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
190 return llvm::is_contained(permutationMap.getResults(), innerDim);
191}
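// A transfer_read in the supported form looks roughly like this (names, shapes
// and padding value are illustrative):
//
//   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//        : memref<32x32xf16>, vector<16x16xf16>
//
// i.e. an unmasked, in-bounds, rank-2 read whose permutation map keeps the most
// minor memref dimension as one of its results.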
192
193// Return true if the transfer op can be converted to a MMA matrix store.
194static bool
195transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
196 // TODO: support 0-d corner case.
197 if (writeOp.getTransferRank() == 0)
198 return false;
199
200 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
201 writeOp.getVectorType().getRank() != 2)
202 return false;
203
204 AffineMap permutationMap = writeOp.getPermutationMap();
205 std::optional<int64_t> stride =
206 getStaticallyKnownRowStride(writeOp.getShapedType(), permutationMap);
207 // Stride of zero means broadcast which is not permitted for writes.
208 if (!stride.has_value() || stride.value() == 0)
209 return false;
210
211 MLIRContext *ctx = writeOp.getContext();
212 AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
213 // TODO: Support transpose once it is added to GPU dialect ops.
214 return permutationMap.getResult(1) == innerDim;
215}
216
217/// Return true if the constant is a splat to a 2D vector so that it can be
218/// converted to a MMA constant matrix op.
219static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
220 auto vecType = dyn_cast<VectorType>(constantOp.getType());
221 if (!vecType || vecType.getRank() != 2)
222 return false;
223 return isa<SplatElementsAttr>(constantOp.getValue());
224}
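// e.g. `%c = arith.constant dense<0.0> : vector<16x16xf16>` qualifies, while a
// non-splat dense constant or a constant of rank != 2 does not.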
225
226/// Return true if this is a broadcast from scalar to a 2D vector.
227static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
228 return broadcastOp.getResultVectorType().getRank() == 2;
229}
230
231/// Return true if this integer extend op can be folded into a contract op.
232template <typename ExtOpTy>
233static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
234 auto transferReadOp =
235 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
236 if (!transferReadOp)
237 return false;
238 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
239}
240
241static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
242
243/// Return the MMA elementwise enum associated with `op` if it is supported.
244/// Return `std::nullopt` otherwise.
245static std::optional<gpu::MMAElementwiseOp>
246convertElementwiseOpToMMA(Operation *op) {
247 if (isa<arith::AddFOp>(op))
248 return gpu::MMAElementwiseOp::ADDF;
249 if (isa<arith::MulFOp>(op))
250 return gpu::MMAElementwiseOp::MULF;
251 if (isa<arith::SubFOp>(op))
252 return gpu::MMAElementwiseOp::SUBF;
253 if (isa<arith::MaximumFOp>(op))
254 return gpu::MMAElementwiseOp::MAXF;
255 if (isa<arith::MinimumFOp>(op))
256 return gpu::MMAElementwiseOp::MINF;
257 if (isa<arith::DivFOp>(op))
258 return gpu::MMAElementwiseOp::DIVF;
259 if (isa<arith::AddIOp>(op))
260 return gpu::MMAElementwiseOp::ADDI;
261 if (isa<arith::MulIOp>(op))
262 return gpu::MMAElementwiseOp::MULI;
263 if (isa<arith::SubIOp>(op))
264 return gpu::MMAElementwiseOp::SUBI;
265 if (isa<arith::DivSIOp>(op))
266 return gpu::MMAElementwiseOp::DIVS;
267 if (isa<arith::DivUIOp>(op))
268 return gpu::MMAElementwiseOp::DIVU;
269 if (isa<arith::NegFOp>(op))
270 return gpu::MMAElementwiseOp::NEGATEF;
271 if (isa<arith::ExtFOp>(op))
272 return gpu::MMAElementwiseOp::EXTF;
273 return std::nullopt;
274}
275
276/// Return true if the op is supported as elementwise op on MMAMatrix type.
277static bool elementwiseSupportsMMAMatrixType(Operation *op) {
278 return convertElementwiseOpToMMA(op).has_value();
279}
280
281/// Returns true if the extract strided slice op is supported with `mma.sync`
282/// path.
283static bool
284extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
285
286 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
287 nvgpu::getWarpMatrixInfo(op);
288 if (failed(warpMatrixInfo))
289 return false;
290
291 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
292 if (failed(contractOp))
293 return false;
294
295 // Handle vector.extract_strided_slice on registers containing
296 // matrixB and matrixC operands. vector.extract_strided_slice op
297 // is not supported on registers containing matrixA operands.
298 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
299 return (cast<VectorType>(op->getResult(0).getType()) ==
300 cast<VectorType>((*contractOp).getRhs().getType()));
301 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
302 return (cast<VectorType>(op->getResult(0).getType()) ==
303 cast<VectorType>((*contractOp).getAcc().getType()));
304
305 return false;
306}
307
308static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
309 if (isa<scf::ForOp, scf::YieldOp>(op))
310 return true;
311 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
312 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
313 : transferReadSupportsMMAMatrixType(transferRead);
314 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
315 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
316 : transferWriteSupportsMMAMatrixType(transferWrite);
317 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
318 return useNvGpu &&
319 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
320 if (auto contract = dyn_cast<vector::ContractionOp>(op))
321 return contractSupportsMMAMatrixType(contract, useNvGpu);
322 if (auto constant = dyn_cast<arith::ConstantOp>(op))
323 return constantSupportsMMAMatrixType(constant);
324 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
325 return broadcastSupportsMMAMatrixType(broadcast);
326 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
327 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
328 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
329 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
330 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
331 return fpExtendSupportsMMAMatrixType(fpExtend);
332 return elementwiseSupportsMMAMatrixType(op);
333}
334
335/// Return an unsorted slice handling scf.for region differently than
336/// `getSlice`. In scf.for we only want to include as part of the slice elements
337/// that are part of the use/def chain.
338static SetVector<Operation *>
339getSliceContract(Operation *op,
340 const BackwardSliceOptions &backwardSliceOptions,
341 const ForwardSliceOptions &forwardSliceOptions) {
342 SetVector<Operation *> slice;
343 slice.insert(op);
344 unsigned currentIndex = 0;
345 SetVector<Operation *> backwardSlice;
346 SetVector<Operation *> forwardSlice;
347 while (currentIndex != slice.size()) {
348 auto *currentOp = (slice)[currentIndex];
349 // Compute and insert the backwardSlice starting from currentOp.
350 backwardSlice.clear();
351 LogicalResult result =
352 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
353 assert(result.succeeded() && "expected a backward slice");
354 (void)result;
355 slice.insert_range(backwardSlice);
356
357 // Compute and insert the forwardSlice starting from currentOp.
358 forwardSlice.clear();
359 // Special case for ForOp: we don't want to include the whole region, but
360 // only the values using the region arguments.
361 // TODO: We should refine this to only care about the region arguments being
362 // converted to matrix type.
363 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
364 for (Value forOpResult : forOp.getResults())
365 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
366 for (BlockArgument &arg : forOp.getRegionIterArgs())
367 getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
368 } else {
369 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
370 }
371 slice.insert_range(forwardSlice);
372 ++currentIndex;
373 }
374 return slice;
375}
376
377// Analyze the slice of operations based on the convert op to figure out if the
378// whole slice can be converted to MMA operations.
379static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
380 bool useNvGpu) {
381 auto hasVectorDest = [](Operation *op) {
382 return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
383 };
384 BackwardSliceOptions backwardSliceOptions;
385 backwardSliceOptions.filter = hasVectorDest;
386
387 auto hasVectorSrc = [](Operation *op) {
388 return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
389 };
390 ForwardSliceOptions forwardSliceOptions;
391 forwardSliceOptions.filter = hasVectorSrc;
392
393 SetVector<Operation *> opToConvert;
394 op->walk([&](Operation *nestedOp) {
395 if (!isa<vector::ContractionOp>(nestedOp) &&
396 !elementwiseSupportsMMAMatrixType(nestedOp))
397 return;
398 if (opToConvert.contains(nestedOp))
399 return;
400 SetVector<Operation *> dependentOps =
401 getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
402 // If any instruction cannot use the MMA matrix type, drop the whole
403 // chain. MMA matrices are stored in an opaque type, so they cannot be
404 // used by all operations.
405 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
406 if (!supportsMMaMatrixType(op, useNvGpu)) {
407 LDBG() << "cannot convert op: " << *op;
408 return true;
409 }
410 return false;
411 }))
412 return;
413
414 opToConvert.insert_range(dependentOps);
415 });
416 // Sort the operations so that we can convert them in topological order.
417 return topologicalSort(opToConvert);
418}
419
420namespace {
421// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
422// to MMA matmul.
423struct PrepareContractToGPUMMA
424 : public OpRewritePattern<vector::ContractionOp> {
425 using Base::Base;
426
427 LogicalResult matchAndRewrite(vector::ContractionOp op,
428 PatternRewriter &rewriter) const override {
429 Location loc = op.getLoc();
430 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
431
432 // Set up the parallel/reduction structure in the right form.
433 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
434 auto infer = [&](MapList m) {
435 return AffineMap::inferFromExprList(m, op.getContext());
436 };
437 AffineExpr m, n, k;
438 bindDims(rewriter.getContext(), m, n, k);
439 static constexpr std::array<int64_t, 2> perm = {1, 0};
440 auto iteratorTypes = op.getIteratorTypes().getValue();
441 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
442 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
443 vector::isParallelIterator(iteratorTypes[1]) &&
444 vector::isReductionIterator(iteratorTypes[2])))
445 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
446 //
447 // Two outer parallel, one inner reduction (matmat flavor).
448 //
449 // This is the classical row-major matmul, nothing to do.
450 if (maps == infer({{m, k}, {k, n}, {m, n}}))
451 return rewriter.notifyMatchFailure(op, "contraction already prepared");
452 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
453 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
454 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
455 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
456 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
457 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
458 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
459 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
460 std::swap(rhs, lhs);
461 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
462 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
463 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
464 std::swap(rhs, lhs);
465 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
466 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
467 std::swap(lhs, rhs);
468 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
469 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
470 std::swap(lhs, rhs);
471 } else {
472 // TODO: llvm_unreachable ?
473 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
474 }
475 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
476 op, lhs, rhs, res,
477 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
478 op.getIteratorTypes());
479 return success();
480 }
481};
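// As an illustration (operand names and shapes are hypothetical), for the
// (m, k) x (n, k) -> (m, n) case the pattern rewrites
//
//   %d = vector.contract ... %a, %b, %c   // B indexed as (n, k)
//
// into
//
//   %bt = vector.transpose %b, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
//   %d  = vector.contract ... %a, %bt, %c // now in (m, k) x (k, n) form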
482
483// Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
484// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
485// respectively. We can fold the transpose operation when loading the data from
486// Shared Memory to registers.
487struct CombineTransferReadOpTranspose final
488 : public OpRewritePattern<vector::TransposeOp> {
489 using Base::Base;
490
491 LogicalResult matchAndRewrite(vector::TransposeOp op,
492 PatternRewriter &rewriter) const override {
493 // Look through integer extend ops.
494 Value source = op.getVector();
495 Type resultType = op.getType();
496 Operation *extOp;
497 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
498 (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
499 (extOp = source.getDefiningOp<arith::ExtFOp>())) {
500 source = extOp->getOperand(0);
501 resultType =
502 VectorType::get(cast<VectorType>(resultType).getShape(),
503 cast<VectorType>(source.getType()).getElementType());
504 }
505
506 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
507 if (!transferReadOp)
508 return rewriter.notifyMatchFailure(op, "no transfer read");
509
510 // TODO: support 0-d corner case.
511 if (transferReadOp.getTransferRank() == 0)
512 return rewriter.notifyMatchFailure(op, "0-D transfer read");
513
514 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
515 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
516
517 AffineMap permutationMap =
518 AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
519 AffineMap newMap =
520 permutationMap.compose(transferReadOp.getPermutationMap());
521
522 auto loc = op.getLoc();
523 Value result = vector::TransferReadOp::create(
524 rewriter, loc, resultType, transferReadOp.getBase(),
525 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
526 transferReadOp.getPadding(), transferReadOp.getMask(),
527 transferReadOp.getInBoundsAttr())
528 .getResult();
529
530 // Fuse through the integer extend op.
531 if (extOp) {
532 if (isa<arith::ExtSIOp>(extOp))
533 result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
534 .getResult();
535 else if (isa<arith::ExtUIOp>(extOp))
536 result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
537 .getResult();
538 else
539 result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
540 .getResult();
541 }
542
543 rewriter.replaceOp(op, result);
544 return success();
545 }
546};
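// Sketch of the rewrite (types and names are illustrative): the pair
//
//   %v = vector.transfer_read %mem[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//        : memref<32x32xf16>, vector<16x8xf16>
//   %t = vector.transpose %v, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
//
// becomes a single transfer_read whose permutation map is the composition of
// the transpose permutation with the original map:
//
//   %t = vector.transfer_read %mem[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//        : memref<32x32xf16>, vector<8x16xf16>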
547
548} // namespace
549
550// MMA types have different layouts based on how they are used in matmul ops.
551// Figure out the right layout to use by looking at op uses.
552// TODO: Change the GPU dialect to abstract the layout at this level and
553// only care about it during lowering to NVVM.
554static const char *inferFragType(Operation *op) {
555 // We can have arith.ext ops before reaching contract ops. See through them
556 // and other kinds of elementwise ops.
557 if (op->hasOneUse()) {
558 Operation *userOp = *op->user_begin();
559 if (userOp->hasTrait<OpTrait::Elementwise>())
560 return inferFragType(userOp);
561 }
562
563 for (Operation *users : op->getUsers()) {
564 auto contract = dyn_cast<vector::ContractionOp>(users);
565 if (!contract)
566 continue;
567 assert(op->getNumResults() == 1);
568 if (contract.getLhs() == op->getResult(0))
569 return "AOp";
570 if (contract.getRhs() == op->getResult(0))
571 return "BOp";
572 }
573 return "COp";
574}
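// For instance, if the (transitive, single-use) consumer of `op` is the LHS of
// a vector.contract this returns "AOp", the RHS returns "BOp", and anything
// else, including the accumulator, falls through to "COp".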
575
576static LogicalResult
577convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
578 llvm::DenseMap<Value, Value> &valueMapping) {
579 OpBuilder::InsertionGuard g(rewriter);
580 rewriter.setInsertionPoint(op);
581
582 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
584 "expected convertible operation");
585
586 AffineMap permutationMap = op.getPermutationMap();
587 std::optional<int64_t> stride =
588 getStaticallyKnownRowStride(op.getShapedType(), permutationMap);
589 if (!stride.has_value()) {
590 LDBG() << "no stride";
591 return rewriter.notifyMatchFailure(op, "no stride");
592 }
593
594 // transferReadSupportsMMAMatrixType ensures that either of the map results is
595 // the most minor dimension. Under this constraint, whether the map represents
596 // a transposed view can be inferred from whether the first result is the most
597 // minor memref dimension.
598 const bool isTranspose = isFirstResultLastMapDimension(permutationMap);
599
600 Value mappingResult = op.getResult();
601 auto elType = op.getVectorType().getElementType();
602 const char *fragType = inferFragType(op);
603 if (op->hasOneUse()) {
604 auto *user = *op->user_begin();
605 // Infer the signedness of the mma type from the integer extend.
606 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
607 elType = IntegerType::get(
608 op.getContext(), cast<IntegerType>(elType).getWidth(),
609 isa<arith::ExtSIOp>(user) ? IntegerType::Signed
610 : IntegerType::Unsigned);
611 mappingResult = user->getResult(0);
612 }
613 }
614 gpu::MMAMatrixType type =
615 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
616 Value load = gpu::SubgroupMmaLoadMatrixOp::create(
617 rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
618 rewriter.getIndexAttr(*stride),
619 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
620 valueMapping[mappingResult] = load;
621
622 LDBG() << "transfer read to: " << load;
623 return success();
624}
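// The emitted load looks roughly like the following (stride, shapes and the
// fragment kind depend on the source memref and the use; values here are
// illustrative):
//
//   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 32 : index}
//        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">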
625
626static LogicalResult
627convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
628 llvm::DenseMap<Value, Value> &valueMapping) {
629 OpBuilder::InsertionGuard g(rewriter);
630 rewriter.setInsertionPoint(op);
631
632 assert(transferWriteSupportsMMAMatrixType(op));
633 std::optional<int64_t> stride =
634 getStaticallyKnownRowStride(op.getShapedType(), op.getPermutationMap());
635 if (!stride.has_value()) {
636 LDBG() << "no stride";
637 return rewriter.notifyMatchFailure(op, "no stride");
638 }
639
640 auto it = valueMapping.find(op.getVector());
641 if (it == valueMapping.end()) {
642 LDBG() << "no mapping";
643 return rewriter.notifyMatchFailure(op, "no mapping");
644 }
645
646 Value matrix = it->second;
647 auto store = gpu::SubgroupMmaStoreMatrixOp::create(
648 rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
649 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
650 (void)store;
651
652 LDBG() << "transfer write to: " << store;
653
654 LDBG() << "erase: " << op;
655 rewriter.eraseOp(op);
656 return success();
657}
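// The emitted store looks roughly like (illustrative values):
//
//   gpu.subgroup_mma_store_matrix %acc, %mem[%i, %j] {leadDimension = 32 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>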
658
659/// Returns the vector type which represents a matrix fragment.
660static VectorType
661getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
662 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
663 regInfo.elementsPerRegister};
664 Type elType = regInfo.registerLLVMType;
665 if (auto vecType = dyn_cast<VectorType>(elType))
666 elType = vecType.getElementType();
667 return VectorType::get(shape, elType);
668}
669
670/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
671static LogicalResult
672convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
673 llvm::DenseMap<Value, Value> &valueMapping) {
674 OpBuilder::InsertionGuard g(rewriter);
675 rewriter.setInsertionPoint(op);
676
677 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
678 nvgpu::getWarpMatrixInfo(op);
679 if (failed(warpMatrixInfo)) {
680 LDBG() << "no warpMatrixInfo";
681 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
682 }
683
684 FailureOr<nvgpu::FragmentElementInfo> regInfo =
685 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
686 if (failed(regInfo)) {
687 LDBG() << "not mma sync reg info";
688 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
689 }
690
691 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
692 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
693 if (!dense) {
694 LDBG() << "not a splat";
695 return rewriter.notifyMatchFailure(op, "not a splat");
696 }
697
698 Value result = arith::ConstantOp::create(
699 rewriter, op.getLoc(), vectorType,
700 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
701 valueMapping[op.getResult()] = result;
702 return success();
703}
704
705/// Check if the loaded matrix operand requires a transpose.
706/// Transposed Map Example:
707/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
708/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
709/// The code below checks if the output 2D is transposed using a generalized
710/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
711/// Returns true if m > n, false otherwise.
712static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
713 mlir::AffineMap map = op.getPermutationMap();
714
715 if (map.getNumResults() != 2) {
716 LDBG() << "Failed because the result of `vector.transfer_read` "
717 "is not a 2d operand";
718 return failure();
719 }
720
721 // Output 2D matrix dimensions in the order of d0, d1.
722 mlir::AffineExpr dM = map.getResult(0);
723 mlir::AffineExpr dN = map.getResult(1);
724
725 // Find the position of these expressions in the input.
726 auto exprM = dyn_cast<AffineDimExpr>(dM);
727 auto exprN = dyn_cast<AffineDimExpr>(dN);
728
729 if (!exprM || !exprN) {
730 LDBG() << "Failed because expressions are not affine dim "
731 "expressions, then transpose cannot be determined.";
732 return failure();
733 }
734
735 return exprM.getPosition() > exprN.getPosition();
736}
737
738static LogicalResult
739creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
740 llvm::DenseMap<Value, Value> &valueMapping) {
741 OpBuilder::InsertionGuard g(rewriter);
742 rewriter.setInsertionPoint(op);
743 Location loc = op->getLoc();
744
745 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
746 nvgpu::getWarpMatrixInfo(op);
747 if (failed(warpMatrixInfo)) {
748 LDBG() << "no warpMatrixInfo";
749 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
750 }
751
752 FailureOr<nvgpu::FragmentElementInfo> regInfo =
753 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
754 if (failed(regInfo)) {
755 LDBG() << "not mma sync reg info";
756 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
757 }
758
759 FailureOr<bool> transpose = isTransposed(op);
760 if (failed(transpose)) {
761 LDBG() << "failed to determine the transpose";
762 return rewriter.notifyMatchFailure(
763 op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
764 }
765
766 FailureOr<nvgpu::LdMatrixParams> params =
767 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
768
769 if (failed(params)) {
770 LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
771 << "Op should likely not be converted to a nvgpu.ldmatrix call.";
772 return rewriter.notifyMatchFailure(
773 op, "failed to convert vector.transfer_read to ldmatrix; this op "
774 "likely should not be converted to a nvgpu.ldmatrix call.");
775 }
776
777 // Adjust the load offset.
778 auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
779 FailureOr<AffineMap> offsets =
780 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
781 if (failed(offsets)) {
782 LDBG() << "no offsets";
783 return rewriter.notifyMatchFailure(op, "no offsets");
784 }
785
786 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
787
788 SmallVector<Value, 4> indices;
789 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
790 indices);
791
792 nvgpu::LdMatrixOp newOp =
793 nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
794 indices, *transpose, params->numTiles);
795 valueMapping[op] = newOp->getResult(0);
796 return success();
797}
798
799static LogicalResult
800createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
801 llvm::DenseMap<Value, Value> &valueMapping) {
802 OpBuilder::InsertionGuard g(rewriter);
803 rewriter.setInsertionPoint(op);
804
805 Location loc = op.getLoc();
806 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
807 nvgpu::getWarpMatrixInfo(op);
808 if (failed(warpMatrixInfo))
809 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
810 FailureOr<nvgpu::FragmentElementInfo> regInfo =
811 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
812 if (failed(regInfo)) {
813 return rewriter.notifyMatchFailure(
814 op, "Failed to deduce register fragment type during "
815 "conversion to distributed non-ldmatrix compatible load");
816 }
817
818 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
819
820 // This is the individual element type.
821 Type loadedElType = regInfo->registerLLVMType;
822 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
823
824 Value fill = arith::ConstantOp::create(
825 rewriter, op.getLoc(), vectorType.getElementType(),
826 rewriter.getZeroAttr(vectorType.getElementType()));
827 Value result =
828 vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
829
830 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
831
832 // If we are not transposing, then we can use vectorized loads. Otherwise, we
833 // must load each element individually.
834 if (!isTransposeLoad) {
835 if (!isa<VectorType>(loadedElType)) {
836 loadedElType = VectorType::get({1}, loadedElType);
837 }
838
839 for (int i = 0; i < vectorType.getShape()[0]; i++) {
840 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
841 rewriter, op.getLoc(), *warpMatrixInfo);
842 if (failed(coords))
843 return rewriter.notifyMatchFailure(op, "no coords");
844
845 Value logicalValueId = arith::ConstantOp::create(
846 rewriter, loc, rewriter.getIndexType(),
847 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
848 SmallVector<Value, 4> newIndices;
849 getXferIndices<vector::TransferReadOp>(
850 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
851
852 Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
853 op.getBase(), newIndices);
854 result = vector::InsertOp::create(rewriter, loc, el, result, i);
855 }
856 } else {
857 if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
858 loadedElType = vecType.getElementType();
859 }
860 for (int i = 0; i < vectorType.getShape()[0]; i++) {
861 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
862 innerIdx++) {
863
864 Value logicalValueId = arith::ConstantOp::create(
865 rewriter, loc, rewriter.getIndexType(),
866 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
867 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
868 rewriter, op.getLoc(), *warpMatrixInfo);
869 if (failed(coords))
870 return rewriter.notifyMatchFailure(op, "no coords");
871
872 SmallVector<Value, 4> newIndices;
873 getXferIndices<vector::TransferReadOp>(
874 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
875 Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
876 op.getBase(), newIndices);
877 result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
878 ArrayRef<int64_t>{i, innerIdx});
879 }
880 }
881 }
882
883 valueMapping[op.getResult()] = result;
884 return success();
885}
886
887/// Return true if this is a shared memory memref type.
888static bool isSharedMemory(MemRefType type) {
889 auto addressSpace =
890 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
891 return addressSpace &&
892 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
893}
894
895/// Converts a `vector.transfer_read` operation directly to either a
896/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
897/// used when converting to `nvgpu.mma.sync` operations.
898static LogicalResult
899convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
900 llvm::DenseMap<Value, Value> &valueMapping) {
901 OpBuilder::InsertionGuard g(rewriter);
902 rewriter.setInsertionPoint(op);
903
904 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
905 nvgpu::getWarpMatrixInfo(op);
906 if (failed(warpMatrixInfo))
907 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
908
909 bool isLdMatrixCompatible =
910 isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
911 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
912
913 VectorType vecTy = op.getVectorType();
914 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
915
916 // When we are transposing the B operand, ldmatrix will only work if we have
917 // at least 8 rows to read and the width to read for the transpose is 128
918 // bits.
919 if (!op.getPermutationMap().isMinorIdentity() &&
920 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
921 vecTy.getDimSize(0) * bitWidth < 128))
922 isLdMatrixCompatible = false;
923
924 if (!isLdMatrixCompatible)
925 return createNonLdMatrixLoads(rewriter, op, valueMapping);
926
927 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
928}
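// Example of the transpose check above (numbers illustrative): for f16
// elements (bitWidth == 16), a transposed read of vector<8x16xf16> gives
// vecTy.getDimSize(0) * bitWidth == 8 * 16 == 128 bits and stays
// ldmatrix-compatible, while vector<4x16xf16> only covers 64 bits and falls
// back to createNonLdMatrixLoads.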
929
930static LogicalResult
931convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
932 llvm::DenseMap<Value, Value> &valueMapping) {
933 OpBuilder::InsertionGuard g(rewriter);
934 rewriter.setInsertionPoint(op);
935
936 Location loc = op->getLoc();
937 auto it = valueMapping.find(op.getVector());
938 if (it == valueMapping.end())
939 return rewriter.notifyMatchFailure(op, "no mapping");
940 Value matrix = it->second;
941
942 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
943 nvgpu::getWarpMatrixInfo(op);
944 if (failed(warpMatrixInfo))
945 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
946 FailureOr<nvgpu::FragmentElementInfo> regInfo =
947 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
948 if (failed(regInfo))
949 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
950
951 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
952 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
953
954 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
955 Value logicalValueId = arith::ConstantOp::create(
956 rewriter, loc, rewriter.getIndexType(),
957 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
958 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
959 rewriter, op.getLoc(), *warpMatrixInfo);
960 if (failed(coords))
961 return rewriter.notifyMatchFailure(op, "no coords");
962
963 Value el =
964 vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
965 SmallVector<Value, 4> newIndices;
966 getXferIndices<vector::TransferWriteOp>(
967 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
968 vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
969 }
970
971 LDBG() << "erase: " << op;
972 rewriter.eraseOp(op);
973 return success();
974}
975
976static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
977 SmallVectorImpl<int64_t> &results) {
978 for (auto attr : arrayAttr)
979 results.push_back(cast<IntegerAttr>(attr).getInt());
980}
981
982static LogicalResult
983convertExtractStridedSlice(RewriterBase &rewriter,
984 vector::ExtractStridedSliceOp op,
985 llvm::DenseMap<Value, Value> &valueMapping) {
986 OpBuilder::InsertionGuard g(rewriter);
987 rewriter.setInsertionPoint(op);
988
989 Location loc = op->getLoc();
990
991 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
992 nvgpu::getWarpMatrixInfo(op);
993 if (failed(warpMatrixInfo))
994 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
995
996 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
997 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
998 if (failed(mmaSyncFragmentInfo))
999 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
1000
1001 // Find the vector.transfer_read whose result vector is being sliced.
1002 auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
1003 if (!transferReadOp)
1004 return rewriter.notifyMatchFailure(op, "no transfer read");
1005
1006 warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
1007 if (failed(warpMatrixInfo))
1008 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
1009
1010 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
1011 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
1012 if (failed(ldFragmentInfo))
1013 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
1014
1015 assert(
1016 (mmaSyncFragmentInfo->elementsPerRegister ==
1017 ldFragmentInfo->elementsPerRegister) &&
1018 "Number of elements per register should be same for load and mma.sync");
1019
1020 // Create vector.extract_strided_slice op for thread-owned fragments.
1021 std::array<int64_t, 2> strides = {1,
1022 1}; // stride for extract slice is always 1.
1023 std::array<int64_t, 2> sliceShape = {
1024 mmaSyncFragmentInfo->numRegistersPerFragment,
1025 mmaSyncFragmentInfo->elementsPerRegister};
1026 auto it = valueMapping.find(transferReadOp);
1027 if (it == valueMapping.end())
1028 return rewriter.notifyMatchFailure(op, "no mapping");
1029 auto sourceVector = it->second;
1030
1031 // Offsets and sizes at the warp level of ownership.
1032 SmallVector<int64_t> offsets;
1033 populateFromInt64AttrArray(op.getOffsets(), offsets);
1034
1036 populateFromInt64AttrArray(op.getSizes(), sizes);
1037 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1038
1039 // Compute offset in vector registers. Note that the mma.sync vector registers
1040 // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1041 // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1042 std::array<int64_t, 2> sliceOffset = {0, 0};
1043
1044 if (offsets[0] && offsets[1])
1045 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1046 if (offsets[0])
1047 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1048 else if (offsets[1])
1049 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1050
1051 Value newOp = vector::ExtractStridedSliceOp::create(
1052 rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1053
1054 valueMapping[op] = newOp;
1055 return success();
1056}
1057
1058static LogicalResult
1059convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1060 llvm::DenseMap<Value, Value> &valueMapping) {
1061 OpBuilder::InsertionGuard g(rewriter);
1062 rewriter.setInsertionPoint(op);
1063
1064 auto itA = valueMapping.find(op.getLhs());
1065 auto itB = valueMapping.find(op.getRhs());
1066 auto itC = valueMapping.find(op.getAcc());
1067 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1068 itC == valueMapping.end())
1069 return rewriter.notifyMatchFailure(op, "no mapping");
1070 Value opA = itA->second, opB = itB->second, opC = itC->second;
1071 Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1072 opC.getType(), opA, opB, opC,
1073 /*a_transpose=*/UnitAttr(),
1074 /*b_transpose=*/UnitAttr());
1075 valueMapping[op.getResult()] = matmul;
1076 return success();
1077}
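// The emitted op has the form (illustrative types):
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//        : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//        -> !gpu.mma_matrix<16x16xf16, "COp">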
1078
1079static LogicalResult
1080convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1081 llvm::DenseMap<Value, Value> &valueMapping) {
1082 OpBuilder::InsertionGuard g(rewriter);
1083 rewriter.setInsertionPoint(op);
1084
1085 auto itA = valueMapping.find(op.getLhs());
1086 auto itB = valueMapping.find(op.getRhs());
1087 auto itC = valueMapping.find(op.getAcc());
1088 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1089 itC == valueMapping.end())
1090 return rewriter.notifyMatchFailure(op, "no mapping");
1091 Value opA = itA->second, opB = itB->second, opC = itC->second;
1092 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1093 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1094 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1095 Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1096 rewriter.getI64ArrayAttr({m, n, k}));
1097 valueMapping[op.getResult()] = matmul;
1098 return success();
1099}
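// The mmaShape is taken from the warp-level shapes of the contraction
// operands; e.g. (illustrative) lhs vector<16x16xf16> and rhs vector<8x16xf16>
// give m = 16, n = 8, k = 16, and the op operates on the distributed register
// fragments mapped above:
//
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>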
1100
1101/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1102static LogicalResult
1103convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1104 llvm::DenseMap<Value, Value> &valueMapping) {
1105 OpBuilder::InsertionGuard g(rewriter);
1106 rewriter.setInsertionPoint(op);
1107
1108 assert(constantSupportsMMAMatrixType(op));
1109
1110 auto splat =
1111 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1112 auto scalarConstant =
1113 arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1114 const char *fragType = inferFragType(op);
1115 auto vecType = cast<VectorType>(op.getType());
1116 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1117 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1118 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1119 type, scalarConstant);
1120 valueMapping[op.getResult()] = matrix;
1121 return success();
1122}
1123
1124/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1125static LogicalResult
1126convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1127 llvm::DenseMap<Value, Value> &valueMapping) {
1128 OpBuilder::InsertionGuard g(rewriter);
1129 rewriter.setInsertionPoint(op);
1130
1131 assert(broadcastSupportsMMAMatrixType(op));
1132
1133 const char *fragType = inferFragType(op);
1134 auto vecType = op.getResultVectorType();
1135 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1136 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1137 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1138 type, op.getSource());
1139 valueMapping[op.getResult()] = matrix;
1140 return success();
1141}
1142
1143// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1144// updated and needs to be updated separately for the loop to be correct.
1145static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1146 scf::ForOp loop,
1147 ValueRange newInitArgs) {
1148 OpBuilder::InsertionGuard g(rewriter);
1149 rewriter.setInsertionPoint(loop);
1150
1151 // Create a new loop before the existing one, with the extra operands.
1152 rewriter.setInsertionPoint(loop);
1153 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1154 llvm::append_range(operands, newInitArgs);
1155 scf::ForOp newLoop =
1156 scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1157 loop.getUpperBound(), loop.getStep(), operands);
1158 rewriter.eraseBlock(newLoop.getBody());
1159
1160 newLoop.getRegion().getBlocks().splice(
1161 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1162 for (Value operand : newInitArgs)
1163 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1164
1165 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1166 loop.getNumResults())))
1167 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1168
1169 LDBG() << "newLoop now: " << newLoop;
1170 LDBG() << "stripped scf.for: " << loop;
1171 LDBG() << "erase: " << loop;
1172
1173 rewriter.eraseOp(loop);
1174 return newLoop;
1175}
1176
1177static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1178 llvm::DenseMap<Value, Value> &valueMapping) {
1179 OpBuilder::InsertionGuard g(rewriter);
1180 rewriter.setInsertionPoint(op);
1181
1182 SmallVector<Value> newOperands;
1183 SmallVector<std::pair<size_t, size_t>> argMapping;
1184 for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1185 auto it = valueMapping.find(operand.value());
1186 if (it == valueMapping.end()) {
1187 LDBG() << "no value mapping for: " << operand.value();
1188 continue;
1189 }
1190 argMapping.push_back(std::make_pair(
1191 operand.index(), op.getInitArgs().size() + newOperands.size()));
1192 newOperands.push_back(it->second);
1193 }
1194
1195 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1196 Block &loopBody = *newForOp.getBody();
1197 for (auto mapping : argMapping) {
1198 valueMapping[newForOp.getResult(mapping.first)] =
1199 newForOp.getResult(mapping.second);
1200 valueMapping[loopBody.getArgument(mapping.first +
1201 newForOp.getNumInductionVars())] =
1202 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1203 }
1204
1205 LDBG() << "scf.for to: " << newForOp;
1206 return success();
1207}
1208
1209static LogicalResult
1210convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1211 llvm::DenseMap<Value, Value> &valueMapping) {
1212 OpBuilder::InsertionGuard g(rewriter);
1213 rewriter.setInsertionPoint(op);
1214
1215 auto loop = cast<scf::ForOp>(op->getParentOp());
1216 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1217 for (const auto &operand : llvm::enumerate(op.getOperands())) {
1218 auto it = valueMapping.find(operand.value());
1219 if (it == valueMapping.end())
1220 continue;
1221 // Replace the yield of old value with the for op argument to make it easier
1222 // to remove the dead code.
1223 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1224 yieldOperands.push_back(it->second);
1225 }
1226 scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1227
1228 LDBG() << "erase: " << op;
1229 rewriter.eraseOp(op);
1230 return success();
1231}
1232
1233/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1234static LogicalResult
1235convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1236 gpu::MMAElementwiseOp opType,
1237 llvm::DenseMap<Value, Value> &valueMapping) {
1238 OpBuilder::InsertionGuard g(rewriter);
1239 rewriter.setInsertionPoint(op);
1240
1241 SmallVector<Value> matrixOperands;
1242 for (Value operand : op->getOperands()) {
1243 auto it = valueMapping.find(operand);
1244 if (it == valueMapping.end())
1245 return rewriter.notifyMatchFailure(op, "no mapping");
1246 matrixOperands.push_back(it->second);
1247 }
1248 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1249 if (opType == gpu::MMAElementwiseOp::EXTF) {
1250 // The floating point extension case has a different result type.
1251 auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1252 resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1253 vectorType.getElementType(),
1254 resultType.getOperand());
1255 }
1256
1257 Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1258 rewriter, op->getLoc(), resultType, matrixOperands, opType);
1259 valueMapping[op->getResult(0)] = newOp;
1260 return success();
1261}
1262
1263void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1264 bool useNvGpu) {
1265 if (!useNvGpu) {
1266 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1267 patterns.getContext());
1268 return;
1269 }
1270 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1271 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1272}
1273
1274LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1275 Operation *rootOp) {
1276 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1277 llvm::DenseMap<Value, Value> valueMapping;
1278
1279 auto globalRes = LogicalResult::success();
1280 for (Operation *op : ops) {
1281 LDBG() << "Process op: " << *op;
1282 // Apparently callers do not want to early exit on failure here.
1283 auto res = LogicalResult::success();
1284 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1285 res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1286 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1287 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1288 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1289 res = convertContractOp(rewriter, contractOp, valueMapping);
1290 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1291 res = convertConstantOp(rewriter, constantOp, valueMapping);
1292 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1293 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1294 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1295 res = convertForOp(rewriter, forOp, valueMapping);
1296 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1297 res = convertYieldOp(rewriter, yieldOp, valueMapping);
1298 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1299 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1300 }
1301 if (failed(res))
1302 globalRes = failure();
1303 }
1304 return globalRes;
1305}
1306
1307LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1308 Operation *rootOp) {
1309 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1310 llvm::DenseMap<Value, Value> valueMapping;
1311 for (Operation *op : ops) {
1312 if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1313 .Case([&](vector::TransferReadOp transferReadOp) {
1314 return convertTransferReadToLoads(rewriter, transferReadOp,
1315 valueMapping);
1316 })
1317 .Case([&](vector::TransferWriteOp transferWriteOp) {
1318 return convertTransferWriteToStores(rewriter, transferWriteOp,
1319 valueMapping);
1320 })
1321 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1322 return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1323 valueMapping);
1324 })
1325 .Case([&](vector::ContractionOp contractionOp) {
1326 return convertContractOpToMmaSync(rewriter, contractionOp,
1327 valueMapping);
1328 })
1329 .Case([&](scf::ForOp forOp) {
1330 return convertForOp(rewriter, forOp, valueMapping);
1331 })
1332 .Case([&](scf::YieldOp yieldOp) {
1333 return convertYieldOp(rewriter, yieldOp, valueMapping);
1334 })
1335 .Case([&](arith::ConstantOp constOp) {
1336 return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1337 })
1338 .Default([&](Operation *op) {
1339 return op->emitError() << "unhandled vector to mma type: " << *op;
1340 })
1341 .failed()) {
1342 return op->emitOpError()
1343 << "failed to convert op during vector-to-nvgpu conversion";
1344 }
1345 }
1346 return success();
1347}
1348
1349namespace {
1350
1351struct ConvertVectorToGPUPass
1352 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1353
1354 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1355 useNvGpu.setValue(useNvGpu_);
1356 }
1357
1358 void runOnOperation() override {
1359 RewritePatternSet patterns(&getContext());
1360 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1361 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1362 return signalPassFailure();
1363
1364 IRRewriter rewriter(&getContext());
1365 if (useNvGpu) {
1366 if (failed(
1367 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1368 return signalPassFailure();
1369 return;
1370 }
1371 (void)convertVectorToMMAOps(rewriter, getOperation());
1372 }
1373};
1374
1375} // namespace
1376
1377std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1378 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1379}
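// The pass registered above is normally driven from the command line, e.g.
// (flag and option names assume the pass definition in Conversion/Passes.td):
//
//   mlir-opt --convert-vector-to-gpu input.mlir
//   mlir-opt --convert-vector-to-gpu="use-nvgpu=true" input.mlir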
return success()
lhs
ArrayAttr()
b getContext())
auto load
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static const char * inferFragType(Operation *op)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool isFirstResultLastMapDimension(AffineMap permutationMap)
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice, handling scf.for regions differently than getSlice.
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expressions.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
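For example, composing a permutation map with itself can recover the identity map. The snippet below is a minimal sketch, not part of this file; the only assumption is a valid MLIRContext *ctx, and permutationComposeExample is a hypothetical helper name.

  #include "mlir/IR/AffineMap.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cassert>

  static void permutationComposeExample(mlir::MLIRContext *ctx) {
    // {1, 0} encodes the 2-D transpose (d0, d1) -> (d1, d0).
    llvm::SmallVector<unsigned, 2> perm = {1, 0};
    mlir::AffineMap transpose = mlir::AffineMap::getPermutationMap(perm, ctx);
    // Composing the transpose with itself yields (d0, d1) -> (d0, d1).
    mlir::AffineMap composed = transpose.compose(transpose);
    assert(composed.isIdentity());
    (void)composed;
  }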
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
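As a quick illustration of the Builder helpers listed above, the sketch below (not from this file) assumes only that an mlir::OpBuilder b is in scope; builderHelpersExample is a hypothetical name.

  #include "mlir/IR/Builders.h"

  static void builderHelpersExample(mlir::OpBuilder &b) {
    mlir::IntegerAttr zero = b.getIndexAttr(0);          // index attribute holding 0
    mlir::ArrayAttr strides = b.getI64ArrayAttr({1, 1}); // i64 array attribute [1, 1]
    mlir::AffineExpr d0 = b.getAffineDimExpr(0);         // affine dimension expression d0
    (void)zero;
    (void)strides;
    (void)d0;
  }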
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
user_iterator user_begin()
Definition Operation.h:869
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
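For instance, the 16x16 f16 fragment for the left-hand matrix operand of a subgroup MMA can be built as in the sketch below; it assumes only a registered MLIRContext *ctx, and makeAFragment is a hypothetical helper name.

  #include "mlir/Dialect/GPU/IR/GPUDialect.h"
  #include "mlir/IR/BuiltinTypes.h"

  static mlir::gpu::MMAMatrixType makeAFragment(mlir::MLIRContext *ctx) {
    mlir::Type f16 = mlir::Float16Type::get(ctx);
    // "AOp" selects the left-hand-side matrix operand; "BOp" and "COp" are the
    // other accepted operand strings.
    return mlir::gpu::MMAMatrixType::get({16, 16}, f16, "AOp");
  }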
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition MMAUtils.cpp:48
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition MMAUtils.cpp:56
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
Returns true if the vector.transfer_write op can be lowered to a warp-level matrix operation.
Definition MMAUtils.cpp:296
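A typical query sequence on the mma.sync path looks roughly like the sketch below; writeOp names an assumed vector.transfer_write that is about to be lowered, and queryWarpInfo is a hypothetical helper.

  #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
  #include "mlir/Dialect/Vector/IR/VectorOps.h"

  static mlir::LogicalResult queryWarpInfo(mlir::vector::TransferWriteOp writeOp) {
    // Bail out early if the write cannot be expressed as a warp-level matrix op.
    if (!mlir::nvgpu::canLowerToWarpMatrixOperation(writeOp))
      return mlir::failure();
    // Otherwise fetch the per-fragment layout information for the stored vector.
    mlir::FailureOr<mlir::nvgpu::WarpMatrixInfo> info =
        mlir::nvgpu::getWarpMatrixInfo(writeOp);
    if (mlir::failed(info))
      return mlir::failure();
    return mlir::success();
  }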
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
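For example, bindDims gives the dimension expressions readable names before a map is assembled. A minimal sketch, assuming only an MLIRContext *ctx; makeTransposeMap is a hypothetical name.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"

  static mlir::AffineMap makeTransposeMap(mlir::MLIRContext *ctx) {
    mlir::AffineExpr d0, d1;
    mlir::bindDims(ctx, d0, d1); // d0 binds to dim 0, d1 to dim 1
    // (d0, d1) -> (d1, d0), i.e. a 2-D transpose map.
    return mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);
  }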
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
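Taken together with populatePrepareVectorToMMAPatterns above, a pattern-driven preparation step might look like the sketch below; rootOp is an assumed single-region operation (e.g. a function) that is isolated from above, as applyPatternsGreedily requires, and prepareForMMA is a hypothetical name.

  #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static mlir::LogicalResult prepareForMMA(mlir::Operation *rootOp, bool useNvGpu) {
    mlir::RewritePatternSet patterns(rootOp->getContext());
    // Canonicalize vector ops into the forms the MMA conversion expects.
    mlir::populatePrepareVectorToMMAPatterns(patterns, useNvGpu);
    return mlir::applyPatternsGreedily(rootOp->getRegion(0), std::move(patterns));
  }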
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SliceOptions ForwardSliceOptions
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
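In a pipeline, the pass-based entry point is added in the usual way; the sketch below assumes a PassManager pm anchored at the level on which this pass is meant to run, and addVectorToGPU is a hypothetical helper.

  #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
  #include "mlir/Pass/PassManager.h"

  static void addVectorToGPU(mlir::PassManager &pm) {
    // true selects the nvgpu mma.sync path; false selects the wmma/MMAMatrix path.
    pm.addPass(mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
  }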
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
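The conversion can also be driven directly with a rewriter, as in the sketch below; rootOp and useNvGpu are assumed inputs, typically supplied by the enclosing pass, and lowerToMMA is a hypothetical name.

  #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
  #include "mlir/IR/PatternMatch.h"

  static mlir::LogicalResult lowerToMMA(mlir::Operation *rootOp, bool useNvGpu) {
    mlir::IRRewriter rewriter(rootOp->getContext());
    // Pick either the nvgpu mma.sync path or the wmma/MMAMatrix path.
    if (useNvGpu)
      return mlir::convertVectorToNVVMCompatibleMMASync(rewriter, rootOp);
    return mlir::convertVectorToMMAOps(rewriter, rootOp);
  }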
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
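The slice utilities compose naturally with topologicalSort: gather both slices into one set, then order it so that definitions precede uses. A sketch, assuming root is the operation of interest; sortedSliceAround is a hypothetical name.

  #include "mlir/Analysis/SliceAnalysis.h"
  #include "mlir/Analysis/TopologicalSortUtils.h"

  static mlir::SetVector<mlir::Operation *> sortedSliceAround(mlir::Operation *root) {
    mlir::SetVector<mlir::Operation *> slice;
    // Operations feeding `root` ...
    (void)mlir::getBackwardSlice(root, &slice, mlir::BackwardSliceOptions());
    // ... and operations reachable from it.
    mlir::getForwardSlice(root, &slice, mlir::ForwardSliceOptions());
    // Topological order so later rewrites can visit defs before uses.
    return mlir::topologicalSort(slice);
  }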
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter