VectorToGPU.cpp
1//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to GPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/Region.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/DebugLog.h"
35
36#define DEBUG_TYPE "vector-to-gpu"
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTVECTORTOGPU
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
46/// AffineMap representing offsets to apply to indices, the function fills
47/// `indices` with the original indices plus the offsets. The offsets are
48/// applied by taking into account the permutation map of the transfer op. If
49/// the `offsetMap` has dimension placeholders, those should be provided in
50/// `dimValues`.
51template <typename TransferOpType>
52static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
53 AffineMap offsetMap, ArrayRef<Value> dimValues,
54 SmallVector<Value, 4> &indices) {
55 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
56 Location loc = xferOp.getLoc();
57 unsigned offsetsIdx = 0;
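// Walk the permutation map results: each AffineDimExpr result offsets the
// corresponding transfer index by the next `offsetMap` result; constant
// (broadcast) results are skipped.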
58 for (auto expr : xferOp.getPermutationMap().getResults()) {
59 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
60 Value prevIdx = indices[dim.getPosition()];
61 SmallVector<OpFoldResult, 3> dims(dimValues);
62 dims.push_back(prevIdx);
63 AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
64 indices[dim.getPosition()] = affine::makeComposedAffineApply(
65 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
66 continue;
67 }
68 }
69}
70
71// Return true if the contract op can be converted to an MMA matmul.
72static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
73 bool useNvGpu) {
74 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
75 auto infer = [&](MapList m) {
76 return AffineMap::inferFromExprList(m, contract.getContext());
77 };
78 AffineExpr m, n, k;
79 bindDims(contract.getContext(), m, n, k);
80 auto iteratorTypes = contract.getIteratorTypes().getValue();
81 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
82 vector::isParallelIterator(iteratorTypes[1]) &&
83 vector::isReductionIterator(iteratorTypes[2])))
84 return false;
85
86 // The contraction needs to represent a matmul to be convertible to an
87 // MMAMatrix matmul.
88 if (!useNvGpu &&
89 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
90 return false;
91 if (useNvGpu &&
92 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
93 return false;
94
95 return true;
96}
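// Illustration (not from the source): the wmma path accepts contractions with
// indexing maps [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)]
// and iterator types [parallel, parallel, reduction]; the mma.sync path instead
// expects the B operand indexed as (m, n, k) -> (n, k).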
97
98// Return true if the given map represents a transposed matrix load,
99// i.e. (d0, d1, ...) -> (dn-1, dn-2).
100static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
101 MLIRContext *ctx = permutationMap.getContext();
102 // Local OpBuilder is fine here, we just build attributes.
103 OpBuilder b(ctx);
104 auto nDim = permutationMap.getNumDims();
105 AffineExpr zero = b.getAffineConstantExpr(0);
106 if (nDim < 2) {
107 // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
108 AffineExpr dim0 = b.getAffineDimExpr(0);
109 return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
110 }
111
112 AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
113 AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
114 // Support both transposed and transposed+broadcasted cases.
115 return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
116 permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
117}
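// For example, affine_map<(d0, d1) -> (d1, d0)> and the broadcasted form
// affine_map<(d0, d1) -> (d1, 0)> both satisfy this predicate.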
118
119// Return the stride for the second-to-last dimension of |type| if it is a
120// memref and has a constant stride.
121static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
122 auto memrefType = dyn_cast<MemRefType>(type);
123 if (!memrefType)
124 return std::nullopt;
125 // If the memref is 0 or 1D the horizontal stride is 0.
126 if (memrefType.getRank() < 2)
127 return 0;
128 int64_t offset = 0;
129 SmallVector<int64_t, 4> strides;
130 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
131 strides.back() != 1)
132 return std::nullopt;
133 int64_t stride = strides[strides.size() - 2];
134 if (stride == ShapedType::kDynamic)
135 return std::nullopt;
136 return stride;
137}
138
139// Return true if the transfer op can be converted to a MMA matrix load.
140static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
141 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
142 readOp.getVectorType().getRank() != 2)
143 return false;
144 if (!getStaticallyKnownRowStride(readOp.getShapedType()))
145 return false;
146
147 // Only allow integer types if the signedness can be inferred.
148 if (readOp.getVectorType().getElementType().isInteger(8))
149 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
150 !isa<arith::ExtUIOp>(*readOp->user_begin())))
151 return false;
152
153 AffineMap map = readOp.getPermutationMap();
154 MLIRContext *ctx = readOp.getContext();
155 AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
156 AffineExpr zero = getAffineConstantExpr(0, ctx);
157 auto broadcastInnerDim =
158 AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
159 return map.isMinorIdentity() || map == broadcastInnerDim ||
160 isTransposeMatrixLoadMap(map);
161}
162
163// Return true if the transfer op can be converted to a MMA matrix store.
164static bool
165transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
166 // TODO: support 0-d corner case.
167 if (writeOp.getTransferRank() == 0)
168 return false;
169
170 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
171 writeOp.getVectorType().getRank() != 2)
172 return false;
173 if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
174 return false;
175 // TODO: Support transpose once it is added to GPU dialect ops.
176 if (!writeOp.getPermutationMap().isMinorIdentity())
177 return false;
178 return true;
179}
180
181/// Return true if the constant is a splat to a 2D vector so that it can be
182/// converted to a MMA constant matrix op.
183static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
184 auto vecType = dyn_cast<VectorType>(constantOp.getType());
185 if (!vecType || vecType.getRank() != 2)
186 return false;
187 return isa<SplatElementsAttr>(constantOp.getValue());
188}
189
190/// Return true if this is a broadcast from scalar to a 2D vector.
191static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
192 return broadcastOp.getResultVectorType().getRank() == 2;
193}
194
195/// Return true if this integer extend op can be folded into a contract op.
196template <typename ExtOpTy>
197static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
198 auto transferReadOp =
199 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
200 if (!transferReadOp)
201 return false;
202 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
203}
204
205static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
206
207/// Return the MMA elementwise enum associated with `op` if it is supported.
208/// Return `std::nullopt` otherwise.
209static std::optional<gpu::MMAElementwiseOp>
210convertElementwiseOpToMMA(Operation *op) {
211 if (isa<arith::AddFOp>(op))
212 return gpu::MMAElementwiseOp::ADDF;
213 if (isa<arith::MulFOp>(op))
214 return gpu::MMAElementwiseOp::MULF;
215 if (isa<arith::SubFOp>(op))
216 return gpu::MMAElementwiseOp::SUBF;
217 if (isa<arith::MaximumFOp>(op))
218 return gpu::MMAElementwiseOp::MAXF;
219 if (isa<arith::MinimumFOp>(op))
220 return gpu::MMAElementwiseOp::MINF;
221 if (isa<arith::DivFOp>(op))
222 return gpu::MMAElementwiseOp::DIVF;
223 if (isa<arith::AddIOp>(op))
224 return gpu::MMAElementwiseOp::ADDI;
225 if (isa<arith::MulIOp>(op))
226 return gpu::MMAElementwiseOp::MULI;
227 if (isa<arith::SubIOp>(op))
228 return gpu::MMAElementwiseOp::SUBI;
229 if (isa<arith::DivSIOp>(op))
230 return gpu::MMAElementwiseOp::DIVS;
231 if (isa<arith::DivUIOp>(op))
232 return gpu::MMAElementwiseOp::DIVU;
233 if (isa<arith::NegFOp>(op))
234 return gpu::MMAElementwiseOp::NEGATEF;
235 if (isa<arith::ExtFOp>(op))
236 return gpu::MMAElementwiseOp::EXTF;
237 return std::nullopt;
238}
239
240/// Return true if the op is supported as elementwise op on MMAMatrix type.
241static bool elementwiseSupportsMMAMatrixType(Operation *op) {
242 return convertElementwiseOpToMMA(op).has_value();
243}
244
245/// Returns true if the extract strided slice op is supported with the
246/// `mma.sync` path.
247static bool
248extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
249
250 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
251 nvgpu::getWarpMatrixInfo(op);
252 if (failed(warpMatrixInfo))
253 return false;
254
255 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
256 if (failed(contractOp))
257 return false;
258
259 // Handle vector.extract_strided_slice on registers containing
260 // matrixB and matrixC operands. vector.extract_strided_slice op
261 // is not supported on registers containing matrixA operands.
262 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
263 return (cast<VectorType>(op->getResult(0).getType()) ==
264 cast<VectorType>((*contractOp).getRhs().getType()));
265 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
266 return (cast<VectorType>(op->getResult(0).getType()) ==
267 cast<VectorType>((*contractOp).getAcc().getType()));
268
269 return false;
270}
271
272static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
273 if (isa<scf::ForOp, scf::YieldOp>(op))
274 return true;
275 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
276 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
277 : transferReadSupportsMMAMatrixType(transferRead);
278 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
279 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
280 : transferWriteSupportsMMAMatrixType(transferWrite);
281 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
282 return useNvGpu &&
283 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
284 if (auto contract = dyn_cast<vector::ContractionOp>(op))
285 return contractSupportsMMAMatrixType(contract, useNvGpu);
286 if (auto constant = dyn_cast<arith::ConstantOp>(op))
287 return constantSupportsMMAMatrixType(constant);
288 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
289 return broadcastSupportsMMAMatrixType(broadcast);
290 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
291 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
292 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
293 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
294 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
295 return fpExtendSupportsMMAMatrixType(fpExtend);
296 return elementwiseSupportsMMAMatrixType(op);
297}
298
299/// Return an unsorted slice, handling the scf.for region differently from
300/// `getSlice`: for scf.for we only want to include, as part of the slice,
301/// elements that are part of the use/def chain.
302static SetVector<Operation *>
303getSliceContract(Operation *op,
304 const BackwardSliceOptions &backwardSliceOptions,
305 const ForwardSliceOptions &forwardSliceOptions) {
306 SetVector<Operation *> slice;
307 slice.insert(op);
308 unsigned currentIndex = 0;
309 SetVector<Operation *> backwardSlice;
310 SetVector<Operation *> forwardSlice;
311 while (currentIndex != slice.size()) {
312 auto *currentOp = (slice)[currentIndex];
313 // Compute and insert the backwardSlice starting from currentOp.
314 backwardSlice.clear();
315 LogicalResult result =
316 getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
317 assert(result.succeeded() && "expected a backward slice");
318 (void)result;
319 slice.insert_range(backwardSlice);
320
321 // Compute and insert the forwardSlice starting from currentOp.
322 forwardSlice.clear();
323 // Special case for ForOp: we don't want to include the whole region, but
324 // only the values using the region arguments.
325 // TODO: We should refine this to only care about the region arguments being
326 // converted to matrix type.
327 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328 for (Value forOpResult : forOp.getResults())
329 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330 for (BlockArgument &arg : forOp.getRegionIterArgs())
331 getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332 } else {
333 getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334 }
335 slice.insert_range(forwardSlice);
336 ++currentIndex;
337 }
338 return slice;
339}
340
341// Analyze the slice of operations anchored at each candidate contraction or
342// elementwise op to figure out if the whole slice can be converted to MMA operations.
343static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344 bool useNvGpu) {
345 auto hasVectorDest = [](Operation *op) {
346 return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347 };
348 BackwardSliceOptions backwardSliceOptions;
349 backwardSliceOptions.filter = hasVectorDest;
350
351 auto hasVectorSrc = [](Operation *op) {
352 return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353 };
354 ForwardSliceOptions forwardSliceOptions;
355 forwardSliceOptions.filter = hasVectorSrc;
356
357 SetVector<Operation *> opToConvert;
358 op->walk([&](Operation *nestedOp) {
359 if (!isa<vector::ContractionOp>(nestedOp) &&
360 !elementwiseSupportsMMAMatrixType(nestedOp))
361 return;
362 if (opToConvert.contains(nestedOp))
363 return;
364 SetVector<Operation *> dependentOps =
365 getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
366 // If any op in the slice cannot use the MMA matrix type, drop the whole
367 // chain. MMA matrices are stored in an opaque type, so they cannot be used
368 // by all operations.
369 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
370 if (!supportsMMaMatrixType(op, useNvGpu)) {
371 LDBG() << "cannot convert op: " << *op;
372 return true;
373 }
374 return false;
375 }))
376 return;
377
378 opToConvert.insert_range(dependentOps);
379 });
380 // Sort the operations so that we can convert them in topological order.
381 return topologicalSort(opToConvert);
382}
383
384namespace {
385// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
386// to MMA matmul.
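// For example, a contraction written with maps {(m, k), (n, k), (m, n)} gets
// its RHS transposed so that the rewritten op uses the canonical
// {(m, k), (k, n), (m, n)} maps (a summary of the cases handled below).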
387struct PrepareContractToGPUMMA
388 : public OpRewritePattern<vector::ContractionOp> {
389 using Base::Base;
390
391 LogicalResult matchAndRewrite(vector::ContractionOp op,
392 PatternRewriter &rewriter) const override {
393 Location loc = op.getLoc();
394 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
395
396 // Set up the parallel/reduction structure in right form.
397 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
398 auto infer = [&](MapList m) {
399 return AffineMap::inferFromExprList(m, op.getContext());
400 };
401 AffineExpr m, n, k;
402 bindDims(rewriter.getContext(), m, n, k);
403 static constexpr std::array<int64_t, 2> perm = {1, 0};
404 auto iteratorTypes = op.getIteratorTypes().getValue();
405 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
406 if (!(vector::isParallelIterator(iteratorTypes[0]) &&
407 vector::isParallelIterator(iteratorTypes[1]) &&
408 vector::isReductionIterator(iteratorTypes[2])))
409 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
410 //
411 // Two outer parallel, one inner reduction (matmat flavor).
412 //
413 // This is the classical row-major matmul, nothing to do.
414 if (maps == infer({{m, k}, {k, n}, {m, n}}))
415 return rewriter.notifyMatchFailure(op, "contraction already prepared");
416 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
417 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
418 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
419 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
420 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
421 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
422 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
423 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
424 std::swap(rhs, lhs);
425 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
426 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
427 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
428 std::swap(rhs, lhs);
429 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
430 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
431 std::swap(lhs, rhs);
432 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
433 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
434 std::swap(lhs, rhs);
435 } else {
436 // TODO: llvm_unreachable ?
437 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
438 }
439 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
440 op, lhs, rhs, res,
441 rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
442 op.getIteratorTypes());
443 return success();
444 }
445};
446
447// Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
448// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
449// respectively. We can fold the transpose operation when loading the data from
450// Shared Memory to registers.
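// Illustrative example (hypothetical types, not from the source):
//   %r = vector.transfer_read %mem[%i, %j], %cst : memref<16x16xf16>, vector<16x16xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// folds into a single vector.transfer_read whose permutation map composes the
// transpose, which the nvgpu path can then lower as a transposed load.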
451struct CombineTransferReadOpTranspose final
452 : public OpRewritePattern<vector::TransposeOp> {
453 using Base::Base;
454
455 LogicalResult matchAndRewrite(vector::TransposeOp op,
456 PatternRewriter &rewriter) const override {
457 // Look through integer extend ops.
458 Value source = op.getVector();
459 Type resultType = op.getType();
460 Operation *extOp;
461 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
462 (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
463 (extOp = source.getDefiningOp<arith::ExtFOp>())) {
464 source = extOp->getOperand(0);
465 resultType =
466 VectorType::get(cast<VectorType>(resultType).getShape(),
467 cast<VectorType>(source.getType()).getElementType());
468 }
469
470 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
471 if (!transferReadOp)
472 return rewriter.notifyMatchFailure(op, "no transfer read");
473
474 // TODO: support 0-d corner case.
475 if (transferReadOp.getTransferRank() == 0)
476 return rewriter.notifyMatchFailure(op, "0-D transfer read");
477
478 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
479 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
480
481 AffineMap permutationMap =
482 AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
483 AffineMap newMap =
484 permutationMap.compose(transferReadOp.getPermutationMap());
485
486 auto loc = op.getLoc();
487 Value result = vector::TransferReadOp::create(
488 rewriter, loc, resultType, transferReadOp.getBase(),
489 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
490 transferReadOp.getPadding(), transferReadOp.getMask(),
491 transferReadOp.getInBoundsAttr())
492 .getResult();
493
494 // Fuse through the integer extend op.
495 if (extOp) {
496 if (isa<arith::ExtSIOp>(extOp))
497 result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
498 .getResult();
499 else if (isa<arith::ExtUIOp>(extOp))
500 result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
501 .getResult();
502 else
503 result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
504 .getResult();
505 }
506
507 rewriter.replaceOp(op, result);
508 return success();
509 }
510};
511
512} // namespace
513
514// MMA types have different layouts based on how they are used in matmul ops.
515// Figure out the right layout to use by looking at op uses.
516// TODO: Change the GPU dialect to abstract the layout at this level and
517// only care about it during lowering to NVVM.
518static const char *inferFragType(Operation *op) {
519 // We can have arith.ext ops before reaching contract ops. See through them
520 // and other kinds of elementwise ops.
521 if (op->hasOneUse()) {
522 Operation *userOp = *op->user_begin();
523 if (userOp->hasTrait<OpTrait::Elementwise>())
524 return inferFragType(userOp);
525 }
526
527 for (Operation *users : op->getUsers()) {
528 auto contract = dyn_cast<vector::ContractionOp>(users);
529 if (!contract)
530 continue;
531 assert(op->getNumResults() == 1);
532 if (contract.getLhs() == op->getResult(0))
533 return "AOp";
534 if (contract.getRhs() == op->getResult(0))
535 return "BOp";
536 }
537 return "COp";
538}
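// "AOp", "BOp" and "COp" correspond to the LHS, RHS and accumulator operands
// of the eventual gpu.subgroup_mma_compute op.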
539
540static LogicalResult
541convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
542 llvm::DenseMap<Value, Value> &valueMapping) {
543 OpBuilder::InsertionGuard g(rewriter);
544 rewriter.setInsertionPoint(op);
545
546 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548 "expected convertible operation");
549
550 std::optional<int64_t> stride =
551 getStaticallyKnownRowStride(op.getShapedType());
552 if (!stride.has_value()) {
553 LDBG() << "no stride";
554 return rewriter.notifyMatchFailure(op, "no stride");
555 }
556
557 AffineMap map = op.getPermutationMap();
558 bool isTranspose = isTransposeMatrixLoadMap(map);
559
560 // Handle broadcast by setting the stride to 0.
561 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
562 assert(cstExpr.getValue() == 0);
563 stride = 0;
564 }
565
566 Value mappingResult = op.getResult();
567 auto elType = op.getVectorType().getElementType();
568 const char *fragType = inferFragType(op);
569 if (op->hasOneUse()) {
570 auto *user = *op->user_begin();
571 // Infer the signedness of the mma type from the integer extend.
572 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
573 elType = IntegerType::get(
574 op.getContext(), cast<IntegerType>(elType).getWidth(),
575 isa<arith::ExtSIOp>(user) ? IntegerType::Signed
576 : IntegerType::Unsigned);
577 mappingResult = user->getResult(0);
578 }
579 }
580 gpu::MMAMatrixType type =
581 gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
582 Value load = gpu::SubgroupMmaLoadMatrixOp::create(
583 rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
584 rewriter.getIndexAttr(*stride),
585 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
586 valueMapping[mappingResult] = load;
587
588 LDBG() << "transfer read to: " << load;
589 return success();
590}
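// The emitted gpu.subgroup_mma_load_matrix carries the row stride as its
// leadDimension attribute and, when the permutation map is a transpose, the
// transpose unit attribute (a sketch of the result; shapes depend on the read).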
591
592static LogicalResult
593convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
594 llvm::DenseMap<Value, Value> &valueMapping) {
595 OpBuilder::InsertionGuard g(rewriter);
596 rewriter.setInsertionPoint(op);
597
598 assert(transferWriteSupportsMMAMatrixType(op));
599 std::optional<int64_t> stride =
600 getStaticallyKnownRowStride(op.getShapedType());
601 if (!stride.has_value()) {
602 LDBG() << "no stride";
603 return rewriter.notifyMatchFailure(op, "no stride");
604 }
605
606 auto it = valueMapping.find(op.getVector());
607 if (it == valueMapping.end()) {
608 LDBG() << "no mapping";
609 return rewriter.notifyMatchFailure(op, "no mapping");
610 }
611
612 Value matrix = it->second;
613 auto store = gpu::SubgroupMmaStoreMatrixOp::create(
614 rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
615 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
616 (void)store;
617
618 LDBG() << "transfer write to: " << store;
619
620 LDBG() << "erase: " << op;
621 rewriter.eraseOp(op);
622 return success();
623}
624
625/// Returns the vector type which represents a matrix fragment.
626static VectorType
627getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
628 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
629 regInfo.elementsPerRegister};
630 Type elType = regInfo.registerLLVMType;
631 if (auto vecType = dyn_cast<VectorType>(elType))
632 elType = vecType.getElementType();
633 return VectorType::get(shape, elType);
634}
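// For instance (hypothetical numbers), a fragment described as 2 registers of
// 2 elements each yields vector<2x2xElType>.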
635
636/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
637static LogicalResult
638convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
639 llvm::DenseMap<Value, Value> &valueMapping) {
640 OpBuilder::InsertionGuard g(rewriter);
641 rewriter.setInsertionPoint(op);
642
643 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
644 nvgpu::getWarpMatrixInfo(op);
645 if (failed(warpMatrixInfo)) {
646 LDBG() << "no warpMatrixInfo";
647 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
648 }
649
650 FailureOr<nvgpu::FragmentElementInfo> regInfo =
651 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
652 if (failed(regInfo)) {
653 LDBG() << "not mma sync reg info";
654 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
655 }
656
657 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
658 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
659 if (!dense) {
660 LDBG() << "not a splat";
661 return rewriter.notifyMatchFailure(op, "not a splat");
662 }
663
664 Value result = arith::ConstantOp::create(
665 rewriter, op.getLoc(), vectorType,
666 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
667 valueMapping[op.getResult()] = result;
668 return success();
669}
670
671/// Check if the loaded matrix operand requires a transpose.
672/// Transposed Map Example:
673/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
674/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
675/// The code below checks if the output 2D is transposed using a generalized
676/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
677/// Returns true if m > n, false otherwise.
678static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
679 mlir::AffineMap map = op.getPermutationMap();
680
681 if (map.getNumResults() != 2) {
682 LDBG() << "Failed because the result of `vector.transfer_read` "
683 "is not a 2d operand";
684 return failure();
685 }
686
687 // Output 2D matrix dimensions in the order of d0, d1.
688 mlir::AffineExpr dM = map.getResult(0);
689 mlir::AffineExpr dN = map.getResult(1);
690
691 // Find the position of these expressions in the input.
692 auto exprM = dyn_cast<AffineDimExpr>(dM);
693 auto exprN = dyn_cast<AffineDimExpr>(dN);
694
695 if (!exprM || !exprN) {
696 LDBG() << "Failed because expressions are not affine dim "
697 "expressions, then transpose cannot be determined.";
698 return failure();
699 }
700
701 return exprM.getPosition() > exprN.getPosition();
702}
703
704static LogicalResult
705creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
706 llvm::DenseMap<Value, Value> &valueMapping) {
707 OpBuilder::InsertionGuard g(rewriter);
708 rewriter.setInsertionPoint(op);
709 Location loc = op->getLoc();
710
711 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
712 nvgpu::getWarpMatrixInfo(op);
713 if (failed(warpMatrixInfo)) {
714 LDBG() << "no warpMatrixInfo";
715 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
716 }
717
718 FailureOr<nvgpu::FragmentElementInfo> regInfo =
719 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
720 if (failed(regInfo)) {
721 LDBG() << "not mma sync reg info";
722 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
723 }
724
725 FailureOr<bool> transpose = isTransposed(op);
726 if (failed(transpose)) {
727 LDBG() << "failed to determine the transpose";
728 return rewriter.notifyMatchFailure(
729 op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
730 }
731
732 FailureOr<nvgpu::LdMatrixParams> params =
733 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
734
735 if (failed(params)) {
736 LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
737 << "Op should likely not be converted to a nvgpu.ldmatrix call.";
738 return rewriter.notifyMatchFailure(
739 op, "failed to convert vector.transfer_read to ldmatrix; this op "
740 "likely should not be converted to a nvgpu.ldmatrix call.");
741 }
742
743 // Adjust the load offset.
744 auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
745 FailureOr<AffineMap> offsets =
746 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
747 if (failed(offsets)) {
748 LDBG() << "no offsets";
749 return rewriter.notifyMatchFailure(op, "no offsets");
750 }
751
752 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
753
754 SmallVector<Value, 4> indices;
755 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
756 indices);
757
758 nvgpu::LdMatrixOp newOp =
759 nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
760 indices, *transpose, params->numTiles);
761 valueMapping[op] = newOp->getResult(0);
762 return success();
763}
764
765static LogicalResult
766createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
767 llvm::DenseMap<Value, Value> &valueMapping) {
768 OpBuilder::InsertionGuard g(rewriter);
769 rewriter.setInsertionPoint(op);
770
771 Location loc = op.getLoc();
772 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
773 nvgpu::getWarpMatrixInfo(op);
774 if (failed(warpMatrixInfo))
775 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
776 FailureOr<nvgpu::FragmentElementInfo> regInfo =
777 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
778 if (failed(regInfo)) {
779 return rewriter.notifyMatchFailure(
780 op, "Failed to deduce register fragment type during "
781 "conversion to distributed non-ldmatrix compatible load");
782 }
783
784 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
785
786 // This is the individual element type.
787 Type loadedElType = regInfo->registerLLVMType;
788 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
789
790 Value fill = arith::ConstantOp::create(
791 rewriter, op.getLoc(), vectorType.getElementType(),
792 rewriter.getZeroAttr(vectorType.getElementType()));
793 Value result =
794 vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
795
796 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
797
798 // If we are not transposing, then we can use vectorized loads. Otherwise, we
799 // must load each element individually.
800 if (!isTransposeLoad) {
801 if (!isa<VectorType>(loadedElType)) {
802 loadedElType = VectorType::get({1}, loadedElType);
803 }
804
805 for (int i = 0; i < vectorType.getShape()[0]; i++) {
806 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
807 rewriter, op.getLoc(), *warpMatrixInfo);
808 if (failed(coords))
809 return rewriter.notifyMatchFailure(op, "no coords");
810
811 Value logicalValueId = arith::ConstantOp::create(
812 rewriter, loc, rewriter.getIndexType(),
813 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
814 SmallVector<Value, 4> newIndices;
815 getXferIndices<vector::TransferReadOp>(
816 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
817
818 Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
819 op.getBase(), newIndices);
820 result = vector::InsertOp::create(rewriter, loc, el, result, i);
821 }
822 } else {
823 if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
824 loadedElType = vecType.getElementType();
825 }
826 for (int i = 0; i < vectorType.getShape()[0]; i++) {
827 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
828 innerIdx++) {
829
830 Value logicalValueId = arith::ConstantOp::create(
831 rewriter, loc, rewriter.getIndexType(),
832 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
833 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
834 rewriter, op.getLoc(), *warpMatrixInfo);
835 if (failed(coords))
836 return rewriter.notifyMatchFailure(op, "no coords");
837
838 SmallVector<Value, 4> newIndices;
839 getXferIndices<vector::TransferReadOp>(
840 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
841 Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
842 op.getBase(), newIndices);
843 result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
844 ArrayRef<int64_t>{i, innerIdx});
845 }
846 }
847 }
848
849 valueMapping[op.getResult()] = result;
850 return success();
851}
852
853/// Return true if this is a shared memory memref type.
854static bool isSharedMemory(MemRefType type) {
855 auto addressSpace =
856 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
857 return addressSpace &&
858 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
859}
860
861/// Converts a `vector.transfer_read` operation directly to either a
862/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
863/// used when converting to `nvgpu.mma.sync` operations.
864static LogicalResult
865convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
866 llvm::DenseMap<Value, Value> &valueMapping) {
867 OpBuilder::InsertionGuard g(rewriter);
868 rewriter.setInsertionPoint(op);
869
870 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
871 nvgpu::getWarpMatrixInfo(op);
872 if (failed(warpMatrixInfo))
873 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
874
875 bool isLdMatrixCompatible =
876 isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
877 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
878
879 VectorType vecTy = op.getVectorType();
880 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
881
882 // When we are transposing the B operand, ldmatrix will only work if we have
883 // at least 8 rows to read and the width to read for the transpose is 128
884 // bits.
885 if (!op.getPermutationMap().isMinorIdentity() &&
886 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
887 vecTy.getDimSize(0) * bitWidth < 128))
888 isLdMatrixCompatible = false;
889
890 if (!isLdMatrixCompatible)
891 return createNonLdMatrixLoads(rewriter, op, valueMapping);
892
893 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
894}
895
896static LogicalResult
897convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
898 llvm::DenseMap<Value, Value> &valueMapping) {
899 OpBuilder::InsertionGuard g(rewriter);
900 rewriter.setInsertionPoint(op);
901
902 Location loc = op->getLoc();
903 auto it = valueMapping.find(op.getVector());
904 if (it == valueMapping.end())
905 return rewriter.notifyMatchFailure(op, "no mapping");
906 Value matrix = it->second;
907
908 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
909 nvgpu::getWarpMatrixInfo(op);
910 if (failed(warpMatrixInfo))
911 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
912 FailureOr<nvgpu::FragmentElementInfo> regInfo =
913 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
914 if (failed(regInfo))
915 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
916
917 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
918 Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
919
920 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
921 Value logicalValueId = arith::ConstantOp::create(
922 rewriter, loc, rewriter.getIndexType(),
923 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
924 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
925 rewriter, op.getLoc(), *warpMatrixInfo);
926 if (failed(coords))
927 return rewriter.notifyMatchFailure(op, "no coords");
928
929 Value el =
930 vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
931 SmallVector<Value, 4> newIndices;
932 getXferIndices<vector::TransferWriteOp>(
933 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
934 vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
935 }
936
937 LDBG() << "erase: " << op;
938 rewriter.eraseOp(op);
939 return success();
940}
941
942static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
943 SmallVectorImpl<int64_t> &results) {
944 for (auto attr : arrayAttr)
945 results.push_back(cast<IntegerAttr>(attr).getInt());
946}
947
948static LogicalResult
949convertExtractStridedSlice(RewriterBase &rewriter,
950 vector::ExtractStridedSliceOp op,
951 llvm::DenseMap<Value, Value> &valueMapping) {
952 OpBuilder::InsertionGuard g(rewriter);
953 rewriter.setInsertionPoint(op);
954
955 Location loc = op->getLoc();
956
957 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
958 nvgpu::getWarpMatrixInfo(op);
959 if (failed(warpMatrixInfo))
960 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
961
962 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
963 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
964 if (failed(mmaSyncFragmentInfo))
965 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
966
967 // Find the vector.transfer_read whose result vector is being sliced.
968 auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
969 if (!transferReadOp)
970 return rewriter.notifyMatchFailure(op, "no transfer read");
971
972 warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
973 if (failed(warpMatrixInfo))
974 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
975
976 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
977 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
978 if (failed(ldFragmentInfo))
979 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
980
981 assert(
982 (mmaSyncFragmentInfo->elementsPerRegister ==
983 ldFragmentInfo->elementsPerRegister) &&
984 "Number of elements per register should be same for load and mma.sync");
985
986 // Create vector.extract_strided_slice op for thread-owned fragments.
987 std::array<int64_t, 2> strides = {1,
988 1}; // stride for extract slice is always 1.
989 std::array<int64_t, 2> sliceShape = {
990 mmaSyncFragmentInfo->numRegistersPerFragment,
991 mmaSyncFragmentInfo->elementsPerRegister};
992 auto it = valueMapping.find(transferReadOp);
993 if (it == valueMapping.end())
994 return rewriter.notifyMatchFailure(op, "no mapping");
995 auto sourceVector = it->second;
996
997 // Offsets and sizes at the warp level of ownership.
998 SmallVector<int64_t> offsets;
999 populateFromInt64AttrArray(op.getOffsets(), offsets);
1000
1001 SmallVector<int64_t> sizes;
1002 populateFromInt64AttrArray(op.getSizes(), sizes);
1003 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1004
1005 // Compute offset in vector registers. Note that the mma.sync vector registers
1006 // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1007 // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1008 std::array<int64_t, 2> sliceOffset = {0, 0};
1009
1010 if (offsets[0] && offsets[1])
1011 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1012 if (offsets[0])
1013 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1014 else if (offsets[1])
1015 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1016
1017 Value newOp = vector::ExtractStridedSliceOp::create(
1018 rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1019
1020 valueMapping[op] = newOp;
1021 return success();
1022}
1023
1024static LogicalResult
1025convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1026 llvm::DenseMap<Value, Value> &valueMapping) {
1027 OpBuilder::InsertionGuard g(rewriter);
1028 rewriter.setInsertionPoint(op);
1029
1030 auto itA = valueMapping.find(op.getLhs());
1031 auto itB = valueMapping.find(op.getRhs());
1032 auto itC = valueMapping.find(op.getAcc());
1033 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1034 itC == valueMapping.end())
1035 return rewriter.notifyMatchFailure(op, "no mapping");
1036 Value opA = itA->second, opB = itB->second, opC = itC->second;
1037 Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1038 opC.getType(), opA, opB, opC,
1039 /*a_transpose=*/UnitAttr(),
1040 /*b_transpose=*/UnitAttr());
1041 valueMapping[op.getResult()] = matmul;
1042 return success();
1043}
1044
1045static LogicalResult
1046convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047 llvm::DenseMap<Value, Value> &valueMapping) {
1048 OpBuilder::InsertionGuard g(rewriter);
1049 rewriter.setInsertionPoint(op);
1050
1051 auto itA = valueMapping.find(op.getLhs());
1052 auto itB = valueMapping.find(op.getRhs());
1053 auto itC = valueMapping.find(op.getAcc());
1054 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055 itC == valueMapping.end())
1056 return rewriter.notifyMatchFailure(op, "no mapping");
1057 Value opA = itA->second, opB = itB->second, opC = itC->second;
1058 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061 Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1062 rewriter.getI64ArrayAttr({m, n, k}));
1063 valueMapping[op.getResult()] = matmul;
1064 return success();
1065}
1066
1067/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1068static LogicalResult
1069convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070 llvm::DenseMap<Value, Value> &valueMapping) {
1071 OpBuilder::InsertionGuard g(rewriter);
1072 rewriter.setInsertionPoint(op);
1073
1074 assert(constantSupportsMMAMatrixType(op));
1075
1076 auto splat =
1077 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078 auto scalarConstant =
1079 arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1080 const char *fragType = inferFragType(op);
1081 auto vecType = cast<VectorType>(op.getType());
1082 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1083 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1085 type, scalarConstant);
1086 valueMapping[op.getResult()] = matrix;
1087 return success();
1088}
1089
1090/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1091static LogicalResult
1092convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093 llvm::DenseMap<Value, Value> &valueMapping) {
1094 OpBuilder::InsertionGuard g(rewriter);
1095 rewriter.setInsertionPoint(op);
1096
1097 assert(broadcastSupportsMMAMatrixType(op));
1098
1099 const char *fragType = inferFragType(op);
1100 auto vecType = op.getResultVectorType();
1101 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1102 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1104 type, op.getSource());
1105 valueMapping[op.getResult()] = matrix;
1106 return success();
1107}
1108
1109// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110// updated and needs to be updated separately for the loop to be correct.
1111static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112 scf::ForOp loop,
1113 ValueRange newInitArgs) {
1114 OpBuilder::InsertionGuard g(rewriter);
1115 rewriter.setInsertionPoint(loop);
1116
1117 // Create a new loop before the existing one, with the extra operands.
1118 rewriter.setInsertionPoint(loop);
1119 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120 llvm::append_range(operands, newInitArgs);
1121 scf::ForOp newLoop =
1122 scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1123 loop.getUpperBound(), loop.getStep(), operands);
1124 rewriter.eraseBlock(newLoop.getBody());
1125
1126 newLoop.getRegion().getBlocks().splice(
1127 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128 for (Value operand : newInitArgs)
1129 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1130
1131 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132 loop.getNumResults())))
1133 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1134
1135 LDBG() << "newLoop now: " << newLoop;
1136 LDBG() << "stripped scf.for: " << loop;
1137 LDBG() << "erase: " << loop;
1138
1139 rewriter.eraseOp(loop);
1140 return newLoop;
1141}
1142
1143static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144 llvm::DenseMap<Value, Value> &valueMapping) {
1145 OpBuilder::InsertionGuard g(rewriter);
1146 rewriter.setInsertionPoint(op);
1147
1148 SmallVector<Value> newOperands;
1149 SmallVector<std::pair<size_t, size_t>> argMapping;
1150 for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1151 auto it = valueMapping.find(operand.value());
1152 if (it == valueMapping.end()) {
1153 LDBG() << "no value mapping for: " << operand.value();
1154 continue;
1155 }
1156 argMapping.push_back(std::make_pair(
1157 operand.index(), op.getInitArgs().size() + newOperands.size()));
1158 newOperands.push_back(it->second);
1159 }
1160
1161 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1162 Block &loopBody = *newForOp.getBody();
1163 for (auto mapping : argMapping) {
1164 valueMapping[newForOp.getResult(mapping.first)] =
1165 newForOp.getResult(mapping.second);
1166 valueMapping[loopBody.getArgument(mapping.first +
1167 newForOp.getNumInductionVars())] =
1168 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1169 }
1170
1171 LDBG() << "scf.for to: " << newForOp;
1172 return success();
1173}
1174
1175static LogicalResult
1176convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177 llvm::DenseMap<Value, Value> &valueMapping) {
1178 OpBuilder::InsertionGuard g(rewriter);
1179 rewriter.setInsertionPoint(op);
1180
1181 auto loop = cast<scf::ForOp>(op->getParentOp());
1182 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1183 for (const auto &operand : llvm::enumerate(op.getOperands())) {
1184 auto it = valueMapping.find(operand.value());
1185 if (it == valueMapping.end())
1186 continue;
1187 // Replace the yield of the old value with the for op argument to make it
1188 // easier to remove the dead code.
1189 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190 yieldOperands.push_back(it->second);
1191 }
1192 scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1193
1194 LDBG() << "erase: " << op;
1195 rewriter.eraseOp(op);
1196 return success();
1197}
1198
1199/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1200static LogicalResult
1201convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1202 gpu::MMAElementwiseOp opType,
1203 llvm::DenseMap<Value, Value> &valueMapping) {
1204 OpBuilder::InsertionGuard g(rewriter);
1205 rewriter.setInsertionPoint(op);
1206
1207 SmallVector<Value> matrixOperands;
1208 for (Value operand : op->getOperands()) {
1209 auto it = valueMapping.find(operand);
1210 if (it == valueMapping.end())
1211 return rewriter.notifyMatchFailure(op, "no mapping");
1212 matrixOperands.push_back(it->second);
1213 }
1214 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215 if (opType == gpu::MMAElementwiseOp::EXTF) {
1216 // The floating point extension case has a different result type.
1217 auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1218 resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1219 vectorType.getElementType(),
1220 resultType.getOperand());
1221 }
1222
1223 Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1224 rewriter, op->getLoc(), resultType, matrixOperands, opType);
1225 valueMapping[op->getResult(0)] = newOp;
1226 return success();
1227}
1228
1229void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1230 bool useNvGpu) {
1231 if (!useNvGpu) {
1232 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233 patterns.getContext());
1234 return;
1235 }
1236 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1237 patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1238}
1239
1240LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1241 Operation *rootOp) {
1242 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1243 llvm::DenseMap<Value, Value> valueMapping;
1244
1245 auto globalRes = LogicalResult::success();
1246 for (Operation *op : ops) {
1247 LDBG() << "Process op: " << *op;
1248 // Apparently callers do not want to early exit on failure here.
1249 auto res = LogicalResult::success();
1250 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1251 res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1252 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1253 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1254 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1255 res = convertContractOp(rewriter, contractOp, valueMapping);
1256 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1257 res = convertConstantOp(rewriter, constantOp, valueMapping);
1258 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1259 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1260 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1261 res = convertForOp(rewriter, forOp, valueMapping);
1262 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1263 res = convertYieldOp(rewriter, yieldOp, valueMapping);
1264 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1266 }
1267 if (failed(res))
1268 globalRes = failure();
1269 }
1270 return globalRes;
1271}
1272
1273LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1274 Operation *rootOp) {
1275 SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1276 llvm::DenseMap<Value, Value> valueMapping;
1277 for (Operation *op : ops) {
1278 if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1279 .Case([&](vector::TransferReadOp transferReadOp) {
1280 return convertTransferReadToLoads(rewriter, transferReadOp,
1281 valueMapping);
1282 })
1283 .Case([&](vector::TransferWriteOp transferWriteOp) {
1284 return convertTransferWriteToStores(rewriter, transferWriteOp,
1285 valueMapping);
1286 })
1287 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288 return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1289 valueMapping);
1290 })
1291 .Case([&](vector::ContractionOp contractionOp) {
1292 return convertContractOpToMmaSync(rewriter, contractionOp,
1293 valueMapping);
1294 })
1295 .Case([&](scf::ForOp forOp) {
1296 return convertForOp(rewriter, forOp, valueMapping);
1297 })
1298 .Case([&](scf::YieldOp yieldOp) {
1299 return convertYieldOp(rewriter, yieldOp, valueMapping);
1300 })
1301 .Case([&](arith::ConstantOp constOp) {
1302 return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1303 })
1304 .Default([&](Operation *op) {
1305 return op->emitError() << "unhandled vector to mma type: " << *op;
1306 })
1307 .failed()) {
1308 return op->emitOpError()
1309 << "failed to convert op during vector-to-nvgpu conversion";
1310 }
1311 }
1312 return success();
1313}
1314
1315namespace {
1316
1317struct ConvertVectorToGPUPass
1318 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319
1320 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321 useNvGpu.setValue(useNvGpu_);
1322 }
1323
1324 void runOnOperation() override {
1325 RewritePatternSet patterns(&getContext());
1326 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1327 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1328 return signalPassFailure();
1329
1330 IRRewriter rewriter(&getContext());
1331 if (useNvGpu) {
1332 if (failed(
1333 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1334 return signalPassFailure();
1335 return;
1336 }
1337 (void)convertVectorToMMAOps(rewriter, getOperation());
1338 }
1339};
1340
1341} // namespace
1342
1343std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1344 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1345}
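// Typical programmatic use (a sketch mirroring ConvertVectorToGPUPass above;
// `ctx` and `op` stand for the MLIRContext and the root operation):
//   RewritePatternSet patterns(ctx);
//   populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//   IRRewriter rewriter(ctx);
//   (void)convertVectorToMMAOps(rewriter, op);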
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
auto load
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static const char * inferFragType(Operation *op)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
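A minimal usage sketch, assuming nothing beyond a fresh MLIRContext, that combines inferFromExprList, getPermutationMap, and compose to build the row-major matmul indexing maps and a transposed variant of the output map; all names here are illustrative.
// Illustrative sketch only: build (m, k), (k, n), (m, n) maps and compose the
// output map with a 2-D transpose.
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

void affineMapSketch() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr m, n, k;
  mlir::bindDims(&ctx, m, n, k); // m = d0, n = d1, k = d2.
  // Each resulting map has 3 dims (the largest dim used) and 2 results.
  llvm::SmallVector<mlir::AffineMap, 4> maps =
      mlir::AffineMap::inferFromExprList({{m, k}, {k, n}, {m, n}}, &ctx);
  // (d0, d1) -> (d1, d0).
  unsigned perm[2] = {1, 0};
  mlir::AffineMap transpose = mlir::AffineMap::getPermutationMap(perm, &ctx);
  // Apply maps[2] first, then the transpose: (d0, d1, d2) -> (d1, d0).
  mlir::AffineMap composed = transpose.compose(maps[2]);
  (void)composed;
}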
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
UnitAttr getUnitAttr()
Definition Builders.cpp:98
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
IndexType getIndexType()
Definition Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
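As a hedged illustration (shape, element type, and values chosen arbitrarily), a dense elements attribute can be assembled from individual element attributes:
// Illustrative only: build a 2x2 f32 DenseElementsAttr from explicit zeros.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

mlir::DenseElementsAttr makeZeros2x2(mlir::MLIRContext *ctx) {
  mlir::OpBuilder b(ctx);
  auto vecTy = mlir::VectorType::get({2, 2}, b.getF32Type());
  mlir::Attribute zero = b.getZeroAttr(b.getF32Type());
  // Four element attributes, one per position of the 2x2 vector.
  return mlir::DenseElementsAttr::get(vecTy, {zero, zero, zero, zero});
}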
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
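A small, assumption-laden sketch of the RAII pattern described above: the guard restores whatever insertion point was active once it leaves scope; the anchor op is illustrative.
// Illustrative only: create ops immediately before `anchor`, then restore the
// builder's previous insertion point automatically.
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

void buildBefore(mlir::OpBuilder &b, mlir::Operation *anchor) {
  mlir::OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(anchor->getBlock(), mlir::Block::iterator(anchor));
  // ... ops created here land right before `anchor` ...
} // `guard` is destroyed here; the original insertion point is restored.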
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
user_iterator user_begin()
Definition Operation.h:869
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
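A hedged sketch (the walked root and the predicate are arbitrary) showing how the traversal and use-tracking accessors above combine:
// Illustrative only: count nested ops that produce exactly one result which
// has a single use, via walk(), getNumResults(), and hasOneUse().
#include "mlir/IR/Operation.h"

unsigned countSingleUseProducers(mlir::Operation *root) {
  unsigned count = 0;
  root->walk([&](mlir::Operation *op) {
    if (op->getNumResults() == 1 && op->hasOneUse())
      ++count;
  });
  return count;
}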
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
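For illustration only (the 16x16 shape, f16 element type, and "AOp" role are arbitrary example choices):
// Illustrative only: a 16x16 f16 "AOp" MMA matrix type.
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"

mlir::gpu::MMAMatrixType makeAOperandType(mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  return mlir::gpu::MMAMatrixType::get({16, 16}, b.getF16Type(), "AOp");
}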
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
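A minimal sketch, assuming the caller provides the builder, location, and index value, of folding a constant offset into a composed affine.apply:
// Illustrative only: compute `idx + 8` as a composed affine.apply, folding
// into any affine.apply that already defines `idx`.
#include "mlir/Dialect/Affine/IR/AffineOps.h"

mlir::Value addConstantOffset(mlir::OpBuilder &b, mlir::Location loc,
                              mlir::Value idx) {
  mlir::AffineExpr d0 = b.getAffineDimExpr(0);
  mlir::affine::AffineApplyOp apply = mlir::affine::makeComposedAffineApply(
      b, loc, d0 + 8, {mlir::OpFoldResult(idx)});
  return apply.getResult();
}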
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition MMAUtils.cpp:48
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition MMAUtils.cpp:56
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
Returns whether the vector.transfer_write instruction can be interpreted as a warp-level cooperative matrix store operation.
Definition MMAUtils.cpp:296
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. sizeof...(exprs)).
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SliceOptions ForwardSliceOptions
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
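A hedged sketch of scheduling the pass from C++; nesting on func.func and enabling the nvgpu path are assumptions of this example, not requirements:
// Illustrative only: run the vector-to-GPU conversion on every function in a
// module, with the nvgpu/mma.sync path enabled.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult runVectorToGPU(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
  return pm.run(module);
}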
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
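A hedged sketch of the programmatic route instead of the pass: prepare the vector ops with the patterns above, then convert them. It assumes rootOp has a single region isolated from above and targets the wmma (non-nvgpu) path.
// Illustrative only: canonicalize vector ops into MMA-friendly form, then
// convert them to MMA matrix operations under `rootOp`.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

mlir::LogicalResult lowerVectorToWMMA(mlir::RewriterBase &rewriter,
                                      mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  // Assumption: `rootOp` has one region and is isolated from above.
  if (mlir::failed(mlir::applyPatternsGreedily(rootOp->getRegion(0),
                                               std::move(patterns))))
    return mlir::failure();
  return mlir::convertVectorToMMAOps(rewriter, rootOp);
}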
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op).
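A small sketch, assuming an arbitrary anchor operation, that combines the slice utilities and topologicalSort listed here:
// Illustrative only: gather the backward and forward slices around `anchor`
// into one set and order it topologically.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h" // declares topologicalSort

mlir::SetVector<mlir::Operation *> collectSlice(mlir::Operation *anchor) {
  mlir::SetVector<mlir::Operation *> slice;
  mlir::BackwardSliceOptions backwardOptions;
  backwardOptions.inclusive = true; // Also include `anchor` itself.
  (void)mlir::getBackwardSlice(anchor, &slice, backwardOptions);
  mlir::getForwardSlice(anchor, &slice);
  return mlir::topologicalSort(slice);
}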
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter