1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Region.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
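// For illustration: with dimValues = {%laneId}, an offset map
//   (d0) -> (d0 floordiv 4, d0 mod 4),
// and a minor-identity permutation map over indices (%i, %j), the rewritten
// indices become (%i + %laneId floordiv 4, %j + %laneId mod 4).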
75 
76 // Return true if the contract op can be converted to an MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
81  AffineExpr m, n, k;
82  bindDims(contract.getContext(), m, n, k);
83  auto iteratorTypes = contract.getIteratorTypes().getValue();
84  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
85  vector::isParallelIterator(iteratorTypes[1]) &&
86  vector::isReductionIterator(iteratorTypes[2])))
87  return false;
88 
89  // The contract needs to represent a matmul to be able to convert to
90  // MMAMatrix matmul.
91  if (!useNvGpu &&
92  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93  return false;
94  if (useNvGpu &&
95  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
96  return false;
97 
98  return true;
99 }
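// Concretely, the wmma path (useNvGpu == false) accepts the row-major matmul
// maps
//   affine_map<(m, n, k) -> (m, k)>   // LHS
//   affine_map<(m, n, k) -> (k, n)>   // RHS
//   affine_map<(m, n, k) -> (m, n)>   // accumulator
// while the mma.sync path expects the RHS indexed as (n, k), i.e. B transposed.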
100 
101 // Return true if the given map represents a transposed matrix load,
102 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
103 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
104  MLIRContext *ctx = permutationMap.getContext();
105  // Local OpBuilder is fine here, we just build attributes.
106  OpBuilder b(ctx);
107  auto nDim = permutationMap.getNumDims();
108  AffineExpr zero = b.getAffineConstantExpr(0);
109  if (nDim < 2) {
110  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
111  AffineExpr dim0 = b.getAffineDimExpr(0);
112  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
113  }
114 
115  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
116  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
117  // Support both transposed and transposed+broadcasted cases.
118  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
119  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
120 }
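// E.g. affine_map<(d0, d1) -> (d1, d0)> and affine_map<(d0, d1) -> (d1, 0)>
// are recognized as transposed loads, while affine_map<(d0, d1) -> (d0, d1)>
// is not.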
121 
122 // Return the stride for the second-to-last dimension of |type| if it is a memref
123 // and has a constant stride.
124 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
125  auto memrefType = dyn_cast<MemRefType>(type);
126  if (!memrefType)
127  return std::nullopt;
128  // If the memref is 0 or 1D the horizontal stride is 0.
129  if (memrefType.getRank() < 2)
130  return 0;
131  int64_t offset = 0;
132  SmallVector<int64_t, 2> strides;
133  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
134  strides.back() != 1)
135  return std::nullopt;
136  int64_t stride = strides[strides.size() - 2];
137  if (stride == ShapedType::kDynamic)
138  return std::nullopt;
139  return stride;
140 }
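// E.g. memref<16x32xf16> with the default row-major layout has strides
// [32, 1], so the returned row stride is 32; a dynamic leading stride such as
// memref<16x32xf16, strided<[?, 1]>> yields std::nullopt.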
141 
142 // Return true if the transfer op can be converted to a MMA matrix load.
143 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
144  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
145  readOp.getVectorType().getRank() != 2)
146  return false;
147  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
148  return false;
149 
150  // Only allow integer types if the signedness can be inferred.
151  if (readOp.getVectorType().getElementType().isInteger(8))
152  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
153  !isa<arith::ExtUIOp>(*readOp->user_begin())))
154  return false;
155 
156  AffineMap map = readOp.getPermutationMap();
157  MLIRContext *ctx = readOp.getContext();
158  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
159  AffineExpr zero = getAffineConstantExpr(0, ctx);
160  auto broadcastInnerDim =
161  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
162  return map.isMinorIdentity() || map == broadcastInnerDim ||
163  isTransposeMatrixLoadMap(map);
164 }
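// In short: only 2-D, unmasked, in-bounds reads with a statically known row
// stride qualify, the permutation map must be a minor identity, an inner-dim
// broadcast, or a (possibly broadcasted) transpose, and i8 reads additionally
// need a single arith.extsi/extui user so the signedness can be inferred.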
165 
166 // Return true if the transfer op can be converted to a MMA matrix store.
167 static bool
168 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
169  // TODO: support 0-d corner case.
170  if (writeOp.getTransferRank() == 0)
171  return false;
172 
173  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
174  writeOp.getVectorType().getRank() != 2)
175  return false;
176  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
177  return false;
178  // TODO: Support transpose once it is added to GPU dialect ops.
179  if (!writeOp.getPermutationMap().isMinorIdentity())
180  return false;
181  return true;
182 }
183 
184 /// Return true if the constant is a splat to a 2D vector so that it can be
185 /// converted to a MMA constant matrix op.
186 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
187  auto vecType = dyn_cast<VectorType>(constantOp.getType());
188  if (!vecType || vecType.getRank() != 2)
189  return false;
190  return isa<SplatElementsAttr>(constantOp.getValue());
191 }
192 
193 /// Return true if this is a broadcast from scalar to a 2D vector.
194 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
195  return broadcastOp.getResultVectorType().getRank() == 2;
196 }
197 
198 /// Return true if this integer extend op can be folded into a contract op.
199 template <typename ExtOpTy>
200 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
201  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
202  return false;
203  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
204  return isa<vector::ContractionOp>(user);
205  });
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214  if (isa<arith::AddFOp>(op))
215  return gpu::MMAElementwiseOp::ADDF;
216  if (isa<arith::MulFOp>(op))
217  return gpu::MMAElementwiseOp::MULF;
218  if (isa<arith::SubFOp>(op))
219  return gpu::MMAElementwiseOp::SUBF;
220  if (isa<arith::MaximumFOp>(op))
221  return gpu::MMAElementwiseOp::MAXF;
222  if (isa<arith::MinimumFOp>(op))
223  return gpu::MMAElementwiseOp::MINF;
224  if (isa<arith::DivFOp>(op))
225  return gpu::MMAElementwiseOp::DIVF;
226  if (isa<arith::AddIOp>(op))
227  return gpu::MMAElementwiseOp::ADDI;
228  if (isa<arith::MulIOp>(op))
229  return gpu::MMAElementwiseOp::MULI;
230  if (isa<arith::SubIOp>(op))
231  return gpu::MMAElementwiseOp::SUBI;
232  if (isa<arith::DivSIOp>(op))
233  return gpu::MMAElementwiseOp::DIVS;
234  if (isa<arith::DivUIOp>(op))
235  return gpu::MMAElementwiseOp::DIVU;
236  if (isa<arith::NegFOp>(op))
237  return gpu::MMAElementwiseOp::NEGATEF;
238  if (isa<arith::ExtFOp>(op))
239  return gpu::MMAElementwiseOp::EXTF;
240  return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245  return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported with `mma.sync`
249 /// path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254  nvgpu::getWarpMatrixInfo(op);
255  if (failed(warpMatrixInfo))
256  return false;
257 
258  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259  if (failed(contractOp))
260  return false;
261 
262  // Handle vector.extract_strided_slice on registers containing
263  // matrixB and matrixC operands. vector.extract_strided_slice op
264  // is not supported on registers containing matrixA operands.
265  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266  return (cast<VectorType>(op->getResult(0).getType()) ==
267  cast<VectorType>((*contractOp).getRhs().getType()));
268  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269  return (cast<VectorType>(op->getResult(0).getType()) ==
270  cast<VectorType>((*contractOp).getAcc().getType()));
271 
272  return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276  if (isa<scf::ForOp, scf::YieldOp>(op))
277  return true;
278  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280  : transferReadSupportsMMAMatrixType(transferRead);
281  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283  : transferWriteSupportsMMAMatrixType(transferWrite);
284  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285  return useNvGpu &&
286  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287  if (auto contract = dyn_cast<vector::ContractionOp>(op))
288  return contractSupportsMMAMatrixType(contract, useNvGpu);
289  if (auto constant = dyn_cast<arith::ConstantOp>(op))
290  return constantSupportsMMAMatrixType(constant);
291  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292  return broadcastSupportsMMAMatrixType(broadcast);
293  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298  return fpExtendSupportsMMAMatrixType(fpExtend);
299  return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice handling scf.for region differently than
303 /// `getSlice`. In scf.for we only want to include as part of the slice elements
304 /// that are part of the use/def chain.
305 static SetVector<Operation *>
306 getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions,
307  ForwardSliceOptions forwardSliceOptions) {
308  SetVector<Operation *> slice;
309  slice.insert(op);
310  unsigned currentIndex = 0;
311  SetVector<Operation *> backwardSlice;
312  SetVector<Operation *> forwardSlice;
313  while (currentIndex != slice.size()) {
314  auto *currentOp = (slice)[currentIndex];
315  // Compute and insert the backwardSlice starting from currentOp.
316  backwardSlice.clear();
317  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
318  slice.insert(backwardSlice.begin(), backwardSlice.end());
319 
320  // Compute and insert the forwardSlice starting from currentOp.
321  forwardSlice.clear();
322  // Special case for ForOp, we don't want to include the whole region but
323  // only the value using the region arguments.
324  // TODO: We should refine this to only care about the region arguments being
325  // converted to matrix type.
326  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
327  for (Value forOpResult : forOp.getResults())
328  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
329  for (BlockArgument &arg : forOp.getRegionIterArgs())
330  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
331  } else {
332  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
333  }
334  slice.insert(forwardSlice.begin(), forwardSlice.end());
335  ++currentIndex;
336  }
337  return slice;
338 }
339 
340 // Analyze slice of operations based on convert op to figure out if the whole
341 // slice can be converted to MMA operations.
342 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
343  bool useNvGpu) {
344  auto hasVectorDest = [](Operation *op) {
345  return llvm::any_of(op->getResultTypes(),
346  [](Type t) { return isa<VectorType>(t); });
347  };
348  BackwardSliceOptions backwardSliceOptions;
349  backwardSliceOptions.filter = hasVectorDest;
350 
351  auto hasVectorSrc = [](Operation *op) {
352  return llvm::any_of(op->getOperandTypes(),
353  [](Type t) { return isa<VectorType>(t); });
354  };
355  ForwardSliceOptions forwardSliceOptions;
356  forwardSliceOptions.filter = hasVectorSrc;
357 
358  SetVector<Operation *> opToConvert;
359  op->walk([&](vector::ContractionOp contract) {
360  if (opToConvert.contains(contract.getOperation()))
361  return;
362  SetVector<Operation *> dependentOps =
363  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
364  // If any instruction cannot use the MMA matrix type, drop the whole
365  // chain. MMA matrices are stored in an opaque type so they cannot be used
366  // by all operations.
367  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
368  if (!supportsMMaMatrixType(op, useNvGpu)) {
369  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
370  return true;
371  }
372  return false;
373  }))
374  return;
375 
376  opToConvert.insert(dependentOps.begin(), dependentOps.end());
377  });
378  // Sort the operations so that we can convert them in topological order.
379  return topologicalSort(opToConvert);
380 }
381 
382 namespace {
383 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
384 // to MMA matmul.
385 struct PrepareContractToGPUMMA
386  : public OpRewritePattern<vector::ContractionOp> {
387  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
388 
389  LogicalResult matchAndRewrite(vector::ContractionOp op,
390  PatternRewriter &rewriter) const override {
391  Location loc = op.getLoc();
392  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
393 
394  // Set up the parallel/reduction structure in right form.
395  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
396  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
397  AffineExpr m, n, k;
398  bindDims(rewriter.getContext(), m, n, k);
399  static constexpr std::array<int64_t, 2> perm = {1, 0};
400  auto iteratorTypes = op.getIteratorTypes().getValue();
401  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
402  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
403  vector::isParallelIterator(iteratorTypes[1]) &&
404  vector::isReductionIterator(iteratorTypes[2])))
405  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
406  //
407  // Two outer parallel, one inner reduction (matmat flavor).
408  //
409  // This is the classical row-major matmul, nothing to do.
410  if (maps == infer({{m, k}, {k, n}, {m, n}}))
411  return rewriter.notifyMatchFailure(op, "contraction already prepared");
412  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
413  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
414  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
415  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
416  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
417  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
418  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
420  std::swap(rhs, lhs);
421  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
422  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
423  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
424  std::swap(rhs, lhs);
425  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
426  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
427  std::swap(lhs, rhs);
428  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
429  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
430  std::swap(lhs, rhs);
431  } else {
432  // TODO: llvm_unreachable ?
433  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
434  }
435  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
436  op, lhs, rhs, res,
437  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
438  op.getIteratorTypes());
439  return success();
440  }
441 };
442 
443 // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
444 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
445 // respectively. We can fold the transpose operation when loading the data from
446 // Shared Memory to registers.
447 struct CombineTransferReadOpTranspose final
448  : public OpRewritePattern<vector::TransposeOp> {
449  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450 
451  LogicalResult matchAndRewrite(vector::TransposeOp op,
452  PatternRewriter &rewriter) const override {
453  // Look through integer extend ops.
454  Value source = op.getVector();
455  Type resultType = op.getType();
456  Operation *extOp;
457  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
458  (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
459  source = extOp->getOperand(0);
460  resultType =
461  VectorType::get(cast<VectorType>(resultType).getShape(),
462  cast<VectorType>(source.getType()).getElementType());
463  }
464 
465  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
466  if (!transferReadOp)
467  return rewriter.notifyMatchFailure(op, "no transfer read");
468 
469  // TODO: support 0-d corner case.
470  if (transferReadOp.getTransferRank() == 0)
471  return rewriter.notifyMatchFailure(op, "0-D transfer read");
472 
473  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
474  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
475 
476  AffineMap permutationMap =
477  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
478  AffineMap newMap =
479  permutationMap.compose(transferReadOp.getPermutationMap());
480 
481  auto loc = op.getLoc();
482  Value result =
483  rewriter
484  .create<vector::TransferReadOp>(
485  loc, resultType, transferReadOp.getSource(),
486  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
487  transferReadOp.getPadding(), transferReadOp.getMask(),
488  transferReadOp.getInBoundsAttr())
489  .getResult();
490 
491  // Fuse through the integer extend op.
492  if (extOp) {
493  if (isa<arith::ExtSIOp>(extOp))
494  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
495  .getResult();
496  else
497  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
498  .getResult();
499  }
500 
501  rewriter.replaceOp(op, result);
502  return success();
503  }
504 };
505 
506 } // namespace
507 
508 // MMA types have different layout based on how they are used in matmul ops.
509 // Figure the right layout to use by looking at op uses.
510 // TODO: Change the GPU dialect to abstract the layout at this level and
511 // only care about it during lowering to NVVM.
512 static const char *inferFragType(Operation *op) {
513  for (Operation *users : op->getUsers()) {
514  auto contract = dyn_cast<vector::ContractionOp>(users);
515  if (!contract)
516  continue;
517  assert(op->getNumResults() == 1);
518  if (contract.getLhs() == op->getResult(0))
519  return "AOp";
520  if (contract.getRhs() == op->getResult(0))
521  return "BOp";
522  }
523  return "COp";
524 }
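// E.g. a value feeding the LHS of a vector.contract is tagged "AOp", the RHS
// "BOp"; anything else (accumulator or no contract user) defaults to "COp".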
525 
526 static LogicalResult
527 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
528  llvm::DenseMap<Value, Value> &valueMapping) {
529  OpBuilder::InsertionGuard g(rewriter);
530  rewriter.setInsertionPoint(op);
531 
532  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
533  assert(transferReadSupportsMMAMatrixType(op) &&
534  "expected convertible operation");
535 
536  std::optional<int64_t> stride =
537  getStaticallyKnownRowStride(op.getShapedType());
538  if (!stride.has_value()) {
539  LLVM_DEBUG(DBGS() << "no stride\n");
540  return rewriter.notifyMatchFailure(op, "no stride");
541  }
542 
543  AffineMap map = op.getPermutationMap();
544  bool isTranspose = isTransposeMatrixLoadMap(map);
545 
546  // Handle broadcast by setting the stride to 0.
547  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
548  assert(cstExpr.getValue() == 0);
549  stride = 0;
550  }
551 
552  Value mappingResult = op.getResult();
553  auto elType = op.getVectorType().getElementType();
554  const char *fragType = inferFragType(op);
555  if (op->hasOneUse()) {
556  auto user = *op->user_begin();
557  // Infer the signedness of the mma type from the integer extend.
558  bool isSignedExtend = isa<arith::ExtSIOp>(user);
559  if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
560  elType = IntegerType::get(
561  op.getContext(), cast<IntegerType>(elType).getWidth(),
562  isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
563  mappingResult = user->getResult(0);
564  fragType = inferFragType(user);
565  }
566  }
567  gpu::MMAMatrixType type =
568  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
569  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
570  op.getLoc(), type, op.getSource(), op.getIndices(),
571  rewriter.getIndexAttr(*stride),
572  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
573  valueMapping[mappingResult] = load;
574 
575  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
576  return success();
577 }
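// Schematically (assuming the read feeds the contract LHS), an in-bounds read
// of vector<16x16xf16> with row stride 32 becomes something like
//   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">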
578 
579 static LogicalResult
580 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
581  llvm::DenseMap<Value, Value> &valueMapping) {
582  OpBuilder::InsertionGuard g(rewriter);
583  rewriter.setInsertionPoint(op);
584 
585  assert(transferWriteSupportsMMAMatrixType(op));
586  std::optional<int64_t> stride =
587  getStaticallyKnownRowStride(op.getShapedType());
588  if (!stride.has_value()) {
589  LLVM_DEBUG(DBGS() << "no stride\n");
590  return rewriter.notifyMatchFailure(op, "no stride");
591  }
592 
593  auto it = valueMapping.find(op.getVector());
594  if (it == valueMapping.end()) {
595  LLVM_DEBUG(DBGS() << "no mapping\n");
596  return rewriter.notifyMatchFailure(op, "no mapping");
597  }
598 
599  Value matrix = it->second;
600  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
601  op.getLoc(), matrix, op.getSource(), op.getIndices(),
602  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
603  (void)store;
604 
605  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
606 
607  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
608  rewriter.eraseOp(op);
609  return success();
610 }
611 
612 /// Returns the vector type which represents a matrix fragment.
613 static VectorType
614 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
615  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
616  regInfo.elementsPerRegister};
617  Type elType = regInfo.registerLLVMType;
618  if (auto vecType = dyn_cast<VectorType>(elType))
619  elType = vecType.getElementType();
620  return VectorType::get(shape, elType);
621 }
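// E.g. a fragment with numRegistersPerFragment = 2 and elementsPerRegister = 2
// holding f16 data maps to vector<2x2xf16>.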
622 
623 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
624 static LogicalResult
625 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
626  llvm::DenseMap<Value, Value> &valueMapping) {
627  OpBuilder::InsertionGuard g(rewriter);
628  rewriter.setInsertionPoint(op);
629 
630  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
631  nvgpu::getWarpMatrixInfo(op);
632  if (failed(warpMatrixInfo)) {
633  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
634  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
635  }
636 
637  FailureOr<nvgpu::FragmentElementInfo> regInfo =
638  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
639  if (failed(regInfo)) {
640  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
641  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
642  }
643 
644  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
645  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
646  if (!dense) {
647  LLVM_DEBUG(DBGS() << "not a splat\n");
648  return rewriter.notifyMatchFailure(op, "not a splat");
649  }
650 
651  Value result = rewriter.create<arith::ConstantOp>(
652  op.getLoc(), vectorType,
653  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
654  valueMapping[op.getResult()] = result;
655  return success();
656 }
657 
658 /// Check if the loaded matrix operand requires a transpose.
659 /// Transposed Map Example:
660 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
661 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
662 /// The code below checks if the output 2D is transposed using a generalized
663 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
664 /// Returns true if m > n, false otherwise.
665 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
666  mlir::AffineMap map = op.getPermutationMap();
667 
668  if (map.getNumResults() != 2) {
669  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
670  "is not a 2d operand\n");
671  return failure();
672  }
673 
674  // Output 2D matrix dimensions in the order of d0, d1.
675  mlir::AffineExpr dM = map.getResult(0);
676  mlir::AffineExpr dN = map.getResult(1);
677 
678  // Find the position of these expressions in the input.
679  auto exprM = dyn_cast<AffineDimExpr>(dM);
680  auto exprN = dyn_cast<AffineDimExpr>(dN);
681 
682  if (!exprM || !exprN) {
683  LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
684  "expressions, then transpose cannot be determined.\n");
685  return failure();
686  }
687 
688  return exprM.getPosition() > exprN.getPosition();
689 }
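// E.g. (d0, d1, d2, d3) -> (d3, d2) returns true (position 3 > 2), while
// (d0, d1, d2, d3) -> (d2, d3) returns false.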
690 
691 static LogicalResult
692 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
693  llvm::DenseMap<Value, Value> &valueMapping) {
694  OpBuilder::InsertionGuard g(rewriter);
695  rewriter.setInsertionPoint(op);
696  Location loc = op->getLoc();
697 
698  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
699  nvgpu::getWarpMatrixInfo(op);
700  if (failed(warpMatrixInfo)) {
701  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
702  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
703  }
704 
705  FailureOr<nvgpu::FragmentElementInfo> regInfo =
706  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
707  if (failed(regInfo)) {
708  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
709  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
710  }
711 
712  FailureOr<bool> transpose = isTransposed(op);
713  if (failed(transpose)) {
714  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
715  return rewriter.notifyMatchFailure(
716  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
717  }
718 
719  FailureOr<nvgpu::LdMatrixParams> params =
720  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
721 
722  if (failed(params)) {
723  LLVM_DEBUG(
724  DBGS()
725  << "failed to convert vector.transfer_read to ldmatrix. "
726  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
727  return rewriter.notifyMatchFailure(
728  op, "failed to convert vector.transfer_read to ldmatrix; this op "
729  "likely should not be converted to a nvgpu.ldmatrix call.");
730  }
731 
732  // Adjust the load offset.
733  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
734  FailureOr<AffineMap> offsets =
735  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
736  if (failed(offsets)) {
737  LLVM_DEBUG(DBGS() << "no offsets\n");
738  return rewriter.notifyMatchFailure(op, "no offsets");
739  }
740 
741  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
742 
743  SmallVector<Value, 4> indices;
744  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
745  indices);
746 
747  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
748  loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
749  valueMapping[op] = newOp->getResult(0);
750  return success();
751 }
752 
753 static LogicalResult
754 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
755  llvm::DenseMap<Value, Value> &valueMapping) {
756  OpBuilder::InsertionGuard g(rewriter);
757  rewriter.setInsertionPoint(op);
758 
759  Location loc = op.getLoc();
760  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
761  nvgpu::getWarpMatrixInfo(op);
762  if (failed(warpMatrixInfo))
763  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
764  FailureOr<nvgpu::FragmentElementInfo> regInfo =
765  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
766  if (failed(regInfo)) {
767  return rewriter.notifyMatchFailure(
768  op, "Failed to deduce register fragment type during "
769  "conversion to distributed non-ldmatrix compatible load");
770  }
771 
772  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
773  SmallVector<Value, 4> elements;
774 
775  // This is the individual element type.
776  Type loadedElType = regInfo->registerLLVMType;
777  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
778 
779  Value fill = rewriter.create<arith::ConstantOp>(
780  op.getLoc(), vectorType.getElementType(),
781  rewriter.getZeroAttr(vectorType.getElementType()));
782  Value result =
783  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
784 
785  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
786 
787  // If we are not transposing, then we can use vectorized loads. Otherwise, we
788  // must load each element individually.
789  if (!isTransposeLoad) {
790  if (!isa<VectorType>(loadedElType)) {
791  loadedElType = VectorType::get({1}, loadedElType);
792  }
793 
794  for (int i = 0; i < vectorType.getShape()[0]; i++) {
795  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
796  rewriter, op.getLoc(), *warpMatrixInfo);
797  if (failed(coords))
798  return rewriter.notifyMatchFailure(op, "no coords");
799 
800  Value logicalValueId = rewriter.create<arith::ConstantOp>(
801  loc, rewriter.getIndexType(),
802  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
803  SmallVector<Value, 4> newIndices;
804  getXferIndices<vector::TransferReadOp>(
805  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
806 
807  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
808  op.getSource(), newIndices);
809  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
810  }
811  } else {
812  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
813  loadedElType = vecType.getElementType();
814  }
815  for (int i = 0; i < vectorType.getShape()[0]; i++) {
816  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
817  innerIdx++) {
818 
819  Value logicalValueId = rewriter.create<arith::ConstantOp>(
820  loc, rewriter.getIndexType(),
821  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
822  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
823  rewriter, op.getLoc(), *warpMatrixInfo);
824  if (failed(coords))
825  return rewriter.notifyMatchFailure(op, "no coords");
826 
827  SmallVector<Value, 4> newIndices;
828  getXferIndices<vector::TransferReadOp>(
829  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
830  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
831  op.getSource(), newIndices);
832  result = rewriter.create<vector::InsertOp>(
833  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
834  }
835  }
836  }
837 
838  valueMapping[op.getResult()] = result;
839  return success();
840 }
841 
842 /// Return true if this is a shared memory memref type.
843 static bool isSharedMemory(MemRefType type) {
844  auto addressSpace =
845  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
846  if (addressSpace &&
847  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
848  return true;
849  return false;
850 }
851 
852 /// Converts a `vector.transfer_read` operation directly to either a
853 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
854 /// used when converting to `nvgpu.mma.sync` operations.
855 static LogicalResult
856 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
857  llvm::DenseMap<Value, Value> &valueMapping) {
858  OpBuilder::InsertionGuard g(rewriter);
859  rewriter.setInsertionPoint(op);
860 
861  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
862  nvgpu::getWarpMatrixInfo(op);
863  if (failed(warpMatrixInfo))
864  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
865 
866  bool isLdMatrixCompatible =
867  isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
868  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
869 
870  VectorType vecTy = op.getVectorType();
871  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
872 
873  // When we are transposing the B operand, ldmatrix will only work if we have
874  // at least 8 rows to read and the width to read for the transpose is 128
875  // bits.
876  if (!op.getPermutationMap().isMinorIdentity() &&
877  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
878  vecTy.getDimSize(0) * bitWidth < 128))
879  isLdMatrixCompatible = false;
880 
881  if (!isLdMatrixCompatible)
882  return createNonLdMatrixLoads(rewriter, op, valueMapping);
883 
884  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
885 }
886 
887 static LogicalResult
888 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
889  llvm::DenseMap<Value, Value> &valueMapping) {
890  OpBuilder::InsertionGuard g(rewriter);
891  rewriter.setInsertionPoint(op);
892 
893  Location loc = op->getLoc();
894  auto it = valueMapping.find(op.getVector());
895  if (it == valueMapping.end())
896  return rewriter.notifyMatchFailure(op, "no mapping");
897  Value matrix = it->second;
898 
899  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
900  nvgpu::getWarpMatrixInfo(op);
901  if (failed(warpMatrixInfo))
902  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
903  FailureOr<nvgpu::FragmentElementInfo> regInfo =
904  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
905  if (failed(regInfo))
906  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
907 
908  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
909  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
910 
911  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
912  Value logicalValueId = rewriter.create<arith::ConstantOp>(
913  loc, rewriter.getIndexType(),
914  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
915  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
916  rewriter, op.getLoc(), *warpMatrixInfo);
917  if (failed(coords))
918  return rewriter.notifyMatchFailure(op, "no coords");
919 
920  Value el =
921  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
922  SmallVector<Value, 4> newIndices;
923  getXferIndices<vector::TransferWriteOp>(
924  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
925  rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
926  }
927 
928  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
929  rewriter.eraseOp(op);
930  return success();
931 }
932 
933 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
934  SmallVectorImpl<int64_t> &results) {
935  for (auto attr : arrayAttr)
936  results.push_back(cast<IntegerAttr>(attr).getInt());
937 }
938 
939 static LogicalResult
940 convertExtractStridedSlice(RewriterBase &rewriter,
941  vector::ExtractStridedSliceOp op,
942  llvm::DenseMap<Value, Value> &valueMapping) {
943  OpBuilder::InsertionGuard g(rewriter);
944  rewriter.setInsertionPoint(op);
945 
946  Location loc = op->getLoc();
947 
948  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
949  nvgpu::getWarpMatrixInfo(op);
950  if (failed(warpMatrixInfo))
951  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
952 
953  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
954  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
955  if (failed(mmaSyncFragmentInfo))
956  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
957 
958  // Find the vector.transfer_read whose result vector is being sliced.
959  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
960  if (!transferReadOp)
961  return rewriter.notifyMatchFailure(op, "no transfer read");
962 
963  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
964  if (failed(warpMatrixInfo))
965  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
966 
967  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
968  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
969  if (failed(ldFragmentInfo))
970  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
971 
972  assert(
973  (mmaSyncFragmentInfo->elementsPerRegister ==
974  ldFragmentInfo->elementsPerRegister) &&
975  "Number of elements per register should be same for load and mma.sync");
976 
977  // Create vector.extract_strided_slice op for thread-owned fragments.
978  std::array<int64_t, 2> strides = {1,
979  1}; // stride for extract slice is always 1.
980  std::array<int64_t, 2> sliceShape = {
981  mmaSyncFragmentInfo->numRegistersPerFragment,
982  mmaSyncFragmentInfo->elementsPerRegister};
983  auto it = valueMapping.find(transferReadOp);
984  if (it == valueMapping.end())
985  return rewriter.notifyMatchFailure(op, "no mapping");
986  auto sourceVector = it->second;
987 
988  // Offsets and sizes at the warp level of ownership.
989  SmallVector<int64_t> offsets;
990  populateFromInt64AttrArray(op.getOffsets(), offsets);
991 
992  SmallVector<int64_t> sizes;
993  populateFromInt64AttrArray(op.getSizes(), sizes);
994  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
995 
996  // Compute offset in vector registers. Note that the mma.sync vector registers
997  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
998  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
999  std::array<int64_t, 2> sliceOffset = {0, 0};
1000 
1001  if (offsets[0] && offsets[1])
1002  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1003  if (offsets[0])
1004  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1005  else if (offsets[1])
1006  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1007 
1008  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1009  loc, sourceVector, sliceOffset, sliceShape, strides);
1010 
1011  valueMapping[op] = newOp;
1012  return success();
1013 }
1014 
1015 static LogicalResult
1016 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1017  llvm::DenseMap<Value, Value> &valueMapping) {
1018  OpBuilder::InsertionGuard g(rewriter);
1019  rewriter.setInsertionPoint(op);
1020 
1021  auto itA = valueMapping.find(op.getLhs());
1022  auto itB = valueMapping.find(op.getRhs());
1023  auto itC = valueMapping.find(op.getAcc());
1024  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1025  itC == valueMapping.end())
1026  return rewriter.notifyMatchFailure(op, "no mapping");
1027  Value opA = itA->second, opB = itB->second, opC = itC->second;
1028  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1029  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1030  /*b_transpose=*/UnitAttr());
1031  valueMapping[op.getResult()] = matmul;
1032  return success();
1033 }
1034 
1035 static LogicalResult
1036 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1037  llvm::DenseMap<Value, Value> &valueMapping) {
1038  OpBuilder::InsertionGuard g(rewriter);
1039  rewriter.setInsertionPoint(op);
1040 
1041  auto itA = valueMapping.find(op.getLhs());
1042  auto itB = valueMapping.find(op.getRhs());
1043  auto itC = valueMapping.find(op.getAcc());
1044  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1045  itC == valueMapping.end())
1046  return rewriter.notifyMatchFailure(op, "no mapping");
1047  Value opA = itA->second, opB = itB->second, opC = itC->second;
1048  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1049  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1050  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1051  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1052  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1053  valueMapping[op.getResult()] = matmul;
1054  return success();
1055 }
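// E.g. for lhs : vector<16x16xf16> and rhs : vector<8x16xf16>, the inferred
// mmaShape is [m = 16, n = 8, k = 16].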
1056 
1057 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1058 static LogicalResult
1059 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1060  llvm::DenseMap<Value, Value> &valueMapping) {
1061  OpBuilder::InsertionGuard g(rewriter);
1062  rewriter.setInsertionPoint(op);
1063 
1064  assert(constantSupportsMMAMatrixType(op));
1065 
1066  auto splat =
1067  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1068  auto scalarConstant =
1069  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1070  const char *fragType = inferFragType(op);
1071  auto vecType = cast<VectorType>(op.getType());
1072  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1073  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1074  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1075  op.getLoc(), type, scalarConstant);
1076  valueMapping[op.getResult()] = matrix;
1077  return success();
1078 }
1079 
1080 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1081 static LogicalResult
1082 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1083  llvm::DenseMap<Value, Value> &valueMapping) {
1084  OpBuilder::InsertionGuard g(rewriter);
1085  rewriter.setInsertionPoint(op);
1086 
1087  assert(broadcastSupportsMMAMatrixType(op));
1088 
1089  const char *fragType = inferFragType(op);
1090  auto vecType = op.getResultVectorType();
1091  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1092  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1093  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1094  op.getLoc(), type, op.getSource());
1095  valueMapping[op.getResult()] = matrix;
1096  return success();
1097 }
1098 
1099 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1100 // updated and needs to be updated separately for the loop to be correct.
1101 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1102  scf::ForOp loop,
1103  ValueRange newInitArgs) {
1104  OpBuilder::InsertionGuard g(rewriter);
1105  rewriter.setInsertionPoint(loop);
1106 
1107  // Create a new loop before the existing one, with the extra operands.
1108  rewriter.setInsertionPoint(loop);
1109  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1110  llvm::append_range(operands, newInitArgs);
1111  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1112  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1113  operands);
1114  newLoop.getBody()->erase();
1115 
1116  newLoop.getRegion().getBlocks().splice(
1117  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1118  for (Value operand : newInitArgs)
1119  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1120 
1121  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1122  loop.getNumResults())))
1123  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1124 
1125  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1126  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1127  LLVM_DEBUG(DBGS() << "erase: " << loop);
1128 
1129  rewriter.eraseOp(loop);
1130  return newLoop;
1131 }
1132 
1133 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1134  llvm::DenseMap<Value, Value> &valueMapping) {
1135  OpBuilder::InsertionGuard g(rewriter);
1136  rewriter.setInsertionPoint(op);
1137 
1138  SmallVector<Value> newOperands;
1139  SmallVector<std::pair<size_t, size_t>> argMapping;
1140  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1141  auto it = valueMapping.find(operand.value());
1142  if (it == valueMapping.end()) {
1143  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1144  continue;
1145  }
1146  argMapping.push_back(std::make_pair(
1147  operand.index(), op.getInitArgs().size() + newOperands.size()));
1148  newOperands.push_back(it->second);
1149  }
1150 
1151  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1152  Block &loopBody = *newForOp.getBody();
1153  for (auto mapping : argMapping) {
1154  valueMapping[newForOp.getResult(mapping.first)] =
1155  newForOp.getResult(mapping.second);
1156  valueMapping[loopBody.getArgument(mapping.first +
1157  newForOp.getNumInductionVars())] =
1158  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1159  }
1160 
1161  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1162  return success();
1163 }
1164 
1165 static LogicalResult
1166 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1167  llvm::DenseMap<Value, Value> &valueMapping) {
1168  OpBuilder::InsertionGuard g(rewriter);
1169  rewriter.setInsertionPoint(op);
1170 
1171  auto loop = cast<scf::ForOp>(op->getParentOp());
1172  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1173  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1174  auto it = valueMapping.find(operand.value());
1175  if (it == valueMapping.end())
1176  continue;
1177  // Replace the yield of old value with the for op argument to make it easier
1178  // to remove the dead code.
1179  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1180  yieldOperands.push_back(it->second);
1181  }
1182  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1183 
1184  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1185  rewriter.eraseOp(op);
1186  return success();
1187 }
1188 
1189 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1190 static LogicalResult
1191 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1192  gpu::MMAElementwiseOp opType,
1193  llvm::DenseMap<Value, Value> &valueMapping) {
1194  OpBuilder::InsertionGuard g(rewriter);
1195  rewriter.setInsertionPoint(op);
1196 
1197  SmallVector<Value> matrixOperands;
1198  for (Value operand : op->getOperands()) {
1199  auto it = valueMapping.find(operand);
1200  if (it == valueMapping.end())
1201  return rewriter.notifyMatchFailure(op, "no mapping");
1202  matrixOperands.push_back(it->second);
1203  }
1204  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
1205  if (opType == gpu::MMAElementwiseOp::EXTF) {
1206  // The floating point extension case has a different result type.
1207  auto vectorType = op->getResultTypes()[0].cast<VectorType>();
1208  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1209  vectorType.getElementType(),
1210  resultType.getOperand());
1211  }
1212 
1213  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1214  op->getLoc(), resultType, matrixOperands, opType);
1215  valueMapping[op->getResult(0)] = newOp;
1216  return success();
1217 }
1218 
1219 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1220  bool useNvGpu) {
1221  if (!useNvGpu) {
1222  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1223  patterns.getContext());
1224  return;
1225  }
1226  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1227  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1228 }
1229 
1230 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1231  Operation *rootOp) {
1232  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1233  llvm::DenseMap<Value, Value> valueMapping;
1234 
1235  auto globalRes = LogicalResult::success();
1236  for (Operation *op : ops) {
1237  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1238  // Apparently callers do not want to early exit on failure here.
1239  auto res = LogicalResult::success();
1240  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1241  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1242  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1243  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1244  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1245  res = convertContractOp(rewriter, contractOp, valueMapping);
1246  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1247  res = convertConstantOp(rewriter, constantOp, valueMapping);
1248  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1249  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1250  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1251  res = convertForOp(rewriter, forOp, valueMapping);
1252  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1253  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1254  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1255  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1256  }
1257  if (failed(res))
1258  globalRes = failure();
1259  }
1260  return globalRes;
1261 }
1262 
1263 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1264  Operation *rootOp) {
1265  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1266  llvm::DenseMap<Value, Value> valueMapping;
1267  for (Operation *op : ops) {
1268  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1269  .Case([&](vector::TransferReadOp transferReadOp) {
1270  return convertTransferReadToLoads(rewriter, transferReadOp,
1271  valueMapping);
1272  })
1273  .Case([&](vector::TransferWriteOp transferWriteOp) {
1274  return convertTransferWriteToStores(rewriter, transferWriteOp,
1275  valueMapping);
1276  })
1277  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1278  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1279  valueMapping);
1280  })
1281  .Case([&](vector::ContractionOp contractionOp) {
1282  return convertContractOpToMmaSync(rewriter, contractionOp,
1283  valueMapping);
1284  })
1285  .Case([&](scf::ForOp forOp) {
1286  return convertForOp(rewriter, forOp, valueMapping);
1287  })
1288  .Case([&](scf::YieldOp yieldOp) {
1289  return convertYieldOp(rewriter, yieldOp, valueMapping);
1290  })
1291  .Case([&](arith::ConstantOp constOp) {
1292  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1293  })
1294  .Default([&](Operation *op) {
1295  return op->emitError() << "unhandled vector to mma type: " << *op;
1296  })
1297  .failed()) {
1298  return op->emitOpError()
1299  << "failed to convert op during vector-to-nvgpu conversion";
1300  }
1301  }
1302  return success();
1303 }
1304 
1305 namespace {
1306 
1307 struct ConvertVectorToGPUPass
1308  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1309 
1310  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1311  useNvGpu.setValue(useNvGpu_);
1312  }
1313 
1314  void runOnOperation() override {
1315  RewritePatternSet patterns(&getContext());
1316  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1317  if (failed(
1318  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1319  return signalPassFailure();
1320 
1321  IRRewriter rewriter(&getContext());
1322  if (useNvGpu) {
1323  if (failed(
1324  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1325  return signalPassFailure();
1326  return;
1327  }
1328  (void)convertVectorToMMAOps(rewriter, getOperation());
1329  }
1330 };
1331 
1332 } // namespace
1333 
1334 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1335  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1336 }
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getSliceContract(Operation *op, BackwardSliceOptions backwardSliceOptions, ForwardSliceOptions forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to apply to indices, the function fills indices with the original indices plus the offsets.
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
MLIRContext * getContext() const
Definition: AffineMap.cpp:321
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map on the most minor dimensions.
Definition: AffineMap.cpp:152
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:374
unsigned getNumResults() const
Definition: AffineMap.cpp:382
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs, and as many symbols as the largest symbol in exprs.
Definition: AffineMap.cpp:292
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:391
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:536
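A small illustrative sketch of these two AffineMap helpers; the values are hypothetical and ctx is assumed to be a valid MLIRContext*.
// (d0, d1) -> (d1, d0): the 2-D transpose permutation.
AffineMap transpose = AffineMap::getPermutationMap({1u, 0u}, ctx);
// Composing the transpose with itself yields the identity map, so
// composed.isIdentity() holds afterwards.
AffineMap composed = transpose.compose(transpose);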
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:361
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
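As a hedged example of this constructor, a splat attribute for a 2-D vector type can be built from a single element value; the builder b and the f16 element type are illustrative assumptions.
// Passing a single element is treated as a splat of that value over the shape.
auto vecTy = VectorType::get({16, 16}, b.getF16Type());
Attribute zero = b.getZeroAttr(b.getF16Type());
DenseElementsAttr splat = DenseElementsAttr::get(vecTy, {zero});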
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:710
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:828
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:776
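A minimal usage sketch of walk: the callback's parameter type selects which nested ops are visited. rootOp is a hypothetical Operation*.
// Collect every vector.contract nested under rootOp.
SmallVector<vector::ContractionOp> contracts;
rootOp->walk([&](vector::ContractionOp contractOp) {
  contracts.push_back(contractOp);
});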
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
user_iterator user_begin()
Definition: Operation.h:848
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:129
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:120
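A minimal sketch of constructing such a fragment type; the builder b is an assumption, and the operand string must be one of "AOp", "BOp" or "COp".
// 16x16 f16 "A" operand fragment, as consumed by gpu.subgroup_mma_* ops.
gpu::MMAMatrixType aFrag =
    gpu::MMAMatrixType::get(/*shape=*/{16, 16}, b.getF16Type(), /*operand=*/"AOp");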
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1124
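A hedged sketch of the call, assuming an OpBuilder b, a Location loc, and an index value iv are already available; names are illustrative.
// Materialize `iv + 4` as a single, maximally composed affine.apply.
AffineExpr d0 = b.getAffineDimExpr(0);
AffineMap addFour = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 4);
Value idx = affine::makeComposedAffineApply(b, loc, addFour, {OpFoldResult(iv)});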
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns two results representing offsets within a matrix operand.
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing offsets within the matrix operand.
Definition: MMAUtils.cpp:238
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed load, return the LdMatrixParams.
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.
Definition: MMAUtils.cpp:276
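A hedged fragment using the nvgpu MMA utilities listed above, assuming readOp is a vector::TransferReadOp on the mma.sync path and the surrounding function returns LogicalResult; error handling is abbreviated.
// Query which warp-level fragment the read produces, then the register
// layout a single thread holds for that fragment.
FailureOr<nvgpu::WarpMatrixInfo> warpInfo =
    nvgpu::getWarpMatrixInfo(readOp.getOperation());
if (failed(warpInfo))
  return failure();
FailureOr<nvgpu::FragmentElementInfo> regInfo =
    nvgpu::getMmaSyncRegisterType(*warpInfo);
if (failed(regInfo))
  return failure();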
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:137
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:132
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction with MMT semantics (matrix matmul with the RHS transposed).
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:334
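A short sketch of bindDims, assuming ctx is a valid MLIRContext*.
// d0 and d1 are bound to dimension positions 0 and 1, and can then be used
// to assemble maps such as the 2-D transpose (d0, d1) -> (d1, d0).
AffineExpr d0, d1;
bindDims(ctx, d0, d1);
AffineMap transposed = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);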
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op).
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
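A hedged fragment showing a typical query, assuming memrefType is a MemRefType and the surrounding function returns std::optional<int64_t>; the names are illustrative.
SmallVector<int64_t> strides;
int64_t offset;
// Fails for layouts that are not in strided form; a unit innermost stride
// means the innermost dimension is contiguous in memory.
if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
    strides.back() != 1)
  return std::nullopt;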
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm.mma.sync lowering path.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:608
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:584
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
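A hedged sketch of how these entry points are typically wired together; the helper name lowerVectorToMMA is hypothetical and rootOp is assumed to be isolated from above.
static LogicalResult lowerVectorToMMA(Operation *rootOp) {
  // First canonicalize vector ops into a form the conversion understands.
  RewritePatternSet patterns(rootOp->getContext());
  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
    return failure();
  // Then rewrite the supported slices to gpu.subgroup_mma_* operations.
  IRRewriter rewriter(rootOp->getContext());
  return convertVectorToMMAOps(rewriter, rootOp);
}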
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op).
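A minimal sketch combining the slice utilities with topologicalSort; anchorOp is a hypothetical Operation*.
// Gather everything reachable from anchorOp through its results, then order
// the set so that producers are visited before their consumers.
SetVector<Operation *> slice;
getForwardSlice(anchorOp, &slice);
SetVector<Operation *> ordered = topologicalSort(slice);
for (Operation *op : ordered) {
  // ... rewrite or analyze `op` here ...
}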
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static LogicalResult success(bool isSuccess=true)
If isSuccess is true a success result is generated, otherwise a 'failure' result is generated.
Definition: LogicalResult.h:30
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
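A self-contained, hypothetical pattern sketch (not code from this file) showing how OpRewritePattern, notifyMatchFailure and the replace helpers fit together.
// Folds vector.broadcast(vector.broadcast(x)) into a single broadcast of x.
struct FoldBroadcastOfBroadcast final
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!producer)
      return rewriter.notifyMatchFailure(op, "source is not a broadcast");
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, op.getType(),
                                                     producer.getSource());
    return success();
  }
};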
TransitiveFilter filter
Definition: SliceAnalysis.h:30
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
Definition: MMAUtils.h:52