VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/Region.h"
30 #include "mlir/Pass/Pass.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 
35 #define DEBUG_TYPE "vector-to-gpu"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 #define DBGSNL() (llvm::dbgs() << "\n")
38 
39 namespace mlir {
40 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
41 #include "mlir/Conversion/Passes.h.inc"
42 } // namespace mlir
43 
44 using namespace mlir;
45 
46 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
47 /// AffineMap representing offsets to apply to indices, the function fills
48 /// `indices` with the original indices plus the offsets. The offsets are
49 /// applied by taking into account the permutation map of the transfer op. If
50 /// the `offsetMap` has dimension placeholders, those should be provided in
51 /// `dimValues`.
52 template <typename TransferOpType>
53 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
54  AffineMap offsetMap, ArrayRef<Value> dimValues,
55  SmallVector<Value, 4> &indices) {
56  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
57  Location loc = xferOp.getLoc();
58  unsigned offsetsIdx = 0;
59  for (auto expr : xferOp.getPermutationMap().getResults()) {
60  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
61  Value prevIdx = indices[dim.getPosition()];
62  SmallVector<OpFoldResult, 3> dims(dimValues);
63  dims.push_back(prevIdx);
64  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
65  indices[dim.getPosition()] = affine::makeComposedAffineApply(
66  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
67  continue;
68  }
69  }
70 }
71 
72 // Return true if the contract op can be converted to an MMA matmul.
73 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
74  bool useNvGpu) {
75  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
76  auto infer = [&](MapList m) {
77  return AffineMap::inferFromExprList(m, contract.getContext());
78  };
79  AffineExpr m, n, k;
80  bindDims(contract.getContext(), m, n, k);
81  auto iteratorTypes = contract.getIteratorTypes().getValue();
82  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
83  vector::isParallelIterator(iteratorTypes[1]) &&
84  vector::isReductionIterator(iteratorTypes[2])))
85  return false;
86 
87  // The contract needs to represent a matmul to be able to convert to
88  // MMAMatrix matmul.
89  if (!useNvGpu &&
90  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
91  return false;
92  if (useNvGpu &&
93  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
94  return false;
95 
96  return true;
97 }
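// For illustration, a contraction that the default (WMMA) path accepts uses
// the classical row-major matmul maps below (d0 = m, d1 = n, d2 = k); the
// nvgpu (mma.sync) path instead expects the RHS indexed as (n, k). SSA names
// and shapes are placeholders, not taken from a real test.
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>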
98 
99 // Return true if the given map represents a transposed matrix load,
100 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
101 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
102  MLIRContext *ctx = permutationMap.getContext();
103  // Local OpBuilder is fine here, we just build attributes.
104  OpBuilder b(ctx);
105  auto nDim = permutationMap.getNumDims();
106  AffineExpr zero = b.getAffineConstantExpr(0);
107  if (nDim < 2) {
108  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
109  AffineExpr dim0 = b.getAffineDimExpr(0);
110  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
111  }
112 
113  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
114  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
115  // Support both transposed and transposed+broadcasted cases.
116  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
117  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
118 }
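// Concretely, for a rank-2 permutation map this matches
//   affine_map<(d0, d1) -> (d1, d0)>   // plain transpose
//   affine_map<(d0, d1) -> (d1, 0)>    // transpose + broadcast
// and, for the rank-1 case, affine_map<(d0) -> (d0, 0)>.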
119 
120 // Return the stride for the second-to-last dimension of |type| if it is a memref
121 // and has a constant stride.
122 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
123  auto memrefType = dyn_cast<MemRefType>(type);
124  if (!memrefType)
125  return std::nullopt;
126  // If the memref is 0-D or 1-D the horizontal stride is 0.
127  if (memrefType.getRank() < 2)
128  return 0;
129  int64_t offset = 0;
130  SmallVector<int64_t, 2> strides;
131  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
132  strides.back() != 1)
133  return std::nullopt;
134  int64_t stride = strides[strides.size() - 2];
135  if (stride == ShapedType::kDynamic)
136  return std::nullopt;
137  return stride;
138 }
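// For example, a row-major memref<16x32xf16> has strides [32, 1], so the row
// stride returned here would be 32; a layout with a dynamic leading stride
// (e.g. memref<16x32xf16, strided<[?, 1]>>) yields std::nullopt.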
139 
140 // Return true if the transfer op can be converted to a MMA matrix load.
141 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
142  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
143  readOp.getVectorType().getRank() != 2)
144  return false;
145  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
146  return false;
147 
148  // Only allow integer types if the signedness can be inferred.
149  if (readOp.getVectorType().getElementType().isInteger(8))
150  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
151  !isa<arith::ExtUIOp>(*readOp->user_begin())))
152  return false;
153 
154  AffineMap map = readOp.getPermutationMap();
155  MLIRContext *ctx = readOp.getContext();
156  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
157  AffineExpr zero = getAffineConstantExpr(0, ctx);
158  auto broadcastInnerDim =
159  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
160  return map.isMinorIdentity() || map == broadcastInnerDim ||
161  isTransposeMatrixLoadMap(map);
162 }
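// A read this predicate accepts might look like the following sketch
// (placeholder names; the permutation map defaults to the minor identity):
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//          : memref<16x16xf16>, vector<16x16xf16>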
163 
164 // Return true if the transfer op can be converted to a MMA matrix store.
165 static bool
166 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
167  // TODO: support 0-d corner case.
168  if (writeOp.getTransferRank() == 0)
169  return false;
170 
171  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
172  writeOp.getVectorType().getRank() != 2)
173  return false;
174  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
175  return false;
176  // TODO: Support transpose once it is added to GPU dialect ops.
177  if (!writeOp.getPermutationMap().isMinorIdentity())
178  return false;
179  return true;
180 }
181 
182 /// Return true if the constant is a splat to a 2D vector so that it can be
183 /// converted to a MMA constant matrix op.
184 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
185  auto vecType = dyn_cast<VectorType>(constantOp.getType());
186  if (!vecType || vecType.getRank() != 2)
187  return false;
188  return isa<SplatElementsAttr>(constantOp.getValue());
189 }
190 
191 /// Return true if this is a broadcast from scalar to a 2D vector.
192 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
193  return broadcastOp.getResultVectorType().getRank() == 2;
194 }
195 
196 /// Return true if this integer extend op can be folded into a contract op.
197 template <typename ExtOpTy>
198 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
199  auto transferReadOp =
200  extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
201  if (!transferReadOp)
202  return false;
203  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
204 }
205 
206 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
207 
208 /// Return the MMA elementwise enum associated with `op` if it is supported.
209 /// Return `std::nullopt` otherwise.
210 static std::optional<gpu::MMAElementwiseOp>
211 convertElementwiseOpToMMA(Operation *op) {
212  if (isa<arith::AddFOp>(op))
213  return gpu::MMAElementwiseOp::ADDF;
214  if (isa<arith::MulFOp>(op))
215  return gpu::MMAElementwiseOp::MULF;
216  if (isa<arith::SubFOp>(op))
217  return gpu::MMAElementwiseOp::SUBF;
218  if (isa<arith::MaximumFOp>(op))
219  return gpu::MMAElementwiseOp::MAXF;
220  if (isa<arith::MinimumFOp>(op))
221  return gpu::MMAElementwiseOp::MINF;
222  if (isa<arith::DivFOp>(op))
223  return gpu::MMAElementwiseOp::DIVF;
224  if (isa<arith::AddIOp>(op))
225  return gpu::MMAElementwiseOp::ADDI;
226  if (isa<arith::MulIOp>(op))
227  return gpu::MMAElementwiseOp::MULI;
228  if (isa<arith::SubIOp>(op))
229  return gpu::MMAElementwiseOp::SUBI;
230  if (isa<arith::DivSIOp>(op))
231  return gpu::MMAElementwiseOp::DIVS;
232  if (isa<arith::DivUIOp>(op))
233  return gpu::MMAElementwiseOp::DIVU;
234  if (isa<arith::NegFOp>(op))
235  return gpu::MMAElementwiseOp::NEGATEF;
236  if (isa<arith::ExtFOp>(op))
237  return gpu::MMAElementwiseOp::EXTF;
238  return std::nullopt;
239 }
240 
241 /// Return true if the op is supported as elementwise op on MMAMatrix type.
242 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
243  return convertElementwiseOpToMMA(op).has_value();
244 }
245 
246 /// Returns true if the extract strided slice op is supported on the
247 /// `mma.sync` path.
248 static bool
249 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
250 
251  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
252  nvgpu::getWarpMatrixInfo(op);
253  if (failed(warpMatrixInfo))
254  return false;
255 
256  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
257  if (failed(contractOp))
258  return false;
259 
260  // Handle vector.extract_strided_slice on registers containing
261  // matrixB and matrixC operands. vector.extract_strided_slice op
262  // is not supported on registers containing matrixA operands.
263  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
264  return (cast<VectorType>(op->getResult(0).getType()) ==
265  cast<VectorType>((*contractOp).getRhs().getType()));
266  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
267  return (cast<VectorType>(op->getResult(0).getType()) ==
268  cast<VectorType>((*contractOp).getAcc().getType()));
269 
270  return false;
271 }
272 
273 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
274  if (isa<scf::ForOp, scf::YieldOp>(op))
275  return true;
276  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
277  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
278  : transferReadSupportsMMAMatrixType(transferRead);
279  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
280  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
281  : transferWriteSupportsMMAMatrixType(transferWrite);
282  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
283  return useNvGpu &&
284  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
285  if (auto contract = dyn_cast<vector::ContractionOp>(op))
286  return contractSupportsMMAMatrixType(contract, useNvGpu);
287  if (auto constant = dyn_cast<arith::ConstantOp>(op))
288  return constantSupportsMMAMatrixType(constant);
289  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
290  return broadcastSupportsMMAMatrixType(broadcast);
291  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
292  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
293  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
294  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
295  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
296  return fpExtendSupportsMMAMatrixType(fpExtend);
297  return elementwiseSupportsMMAMatrixType(op);
298 }
299 
300 /// Return an unsorted slice handling scf.for region differently than
301 /// `getSlice`. In scf.for we only want to include as part of the slice the
302 /// elements that are part of the use/def chain.
303 static SetVector<Operation *>
304 getSliceContract(Operation *op,
305  const BackwardSliceOptions &backwardSliceOptions,
306  const ForwardSliceOptions &forwardSliceOptions) {
307  SetVector<Operation *> slice;
308  slice.insert(op);
309  unsigned currentIndex = 0;
310  SetVector<Operation *> backwardSlice;
311  SetVector<Operation *> forwardSlice;
312  while (currentIndex != slice.size()) {
313  auto *currentOp = (slice)[currentIndex];
314  // Compute and insert the backwardSlice starting from currentOp.
315  backwardSlice.clear();
316  LogicalResult result =
317  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
318  assert(result.succeeded() && "expected a backward slice");
319  (void)result;
320  slice.insert_range(backwardSlice);
321 
322  // Compute and insert the forwardSlice starting from currentOp.
323  forwardSlice.clear();
324  // Special case for ForOp, we don't want to include the whole region but
325  // only the value using the region arguments.
326  // TODO: We should refine this to only care about the region arguments being
327  // converted to matrix type.
328  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
329  for (Value forOpResult : forOp.getResults())
330  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
331  for (BlockArgument &arg : forOp.getRegionIterArgs())
332  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
333  } else {
334  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
335  }
336  slice.insert_range(forwardSlice);
337  ++currentIndex;
338  }
339  return slice;
340 }
341 
342 // Analyze slice of operations based on convert op to figure out if the whole
343 // slice can be converted to MMA operations.
344 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
345  bool useNvGpu) {
346  auto hasVectorDest = [](Operation *op) {
347  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
348  };
349  BackwardSliceOptions backwardSliceOptions;
350  backwardSliceOptions.filter = hasVectorDest;
351 
352  auto hasVectorSrc = [](Operation *op) {
353  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
354  };
355  ForwardSliceOptions forwardSliceOptions;
356  forwardSliceOptions.filter = hasVectorSrc;
357 
358  SetVector<Operation *> opToConvert;
359  op->walk([&](vector::ContractionOp contract) {
360  if (opToConvert.contains(contract.getOperation()))
361  return;
362  SetVector<Operation *> dependentOps =
363  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
364  // If any instruction cannot use the MMA matrix type, drop the whole
365  // chain. MMA matrices are stored in an opaque type so they cannot be used
366  // by all operations.
367  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
368  if (!supportsMMaMatrixType(op, useNvGpu)) {
369  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
370  return true;
371  }
372  return false;
373  }))
374  return;
375 
376  opToConvert.insert_range(dependentOps);
377  });
378  // Sort the operations so that we can convert them in topological order.
379  return topologicalSort(opToConvert);
380 }
381 
382 namespace {
383 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
384 // to MMA matmul.
385 struct PrepareContractToGPUMMA
386  : public OpRewritePattern<vector::ContractionOp> {
387  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
388 
389  LogicalResult matchAndRewrite(vector::ContractionOp op,
390  PatternRewriter &rewriter) const override {
391  Location loc = op.getLoc();
392  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
393 
394  // Set up the parallel/reduction structure in right form.
395  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
396  auto infer = [&](MapList m) {
397  return AffineMap::inferFromExprList(m, op.getContext());
398  };
399  AffineExpr m, n, k;
400  bindDims(rewriter.getContext(), m, n, k);
401  static constexpr std::array<int64_t, 2> perm = {1, 0};
402  auto iteratorTypes = op.getIteratorTypes().getValue();
403  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
404  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
405  vector::isParallelIterator(iteratorTypes[1]) &&
406  vector::isReductionIterator(iteratorTypes[2])))
407  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
408  //
409  // Two outer parallel, one inner reduction (matmat flavor).
410  //
411  // This is the classical row-major matmul, nothing to do.
412  if (maps == infer({{m, k}, {k, n}, {m, n}}))
413  return rewriter.notifyMatchFailure(op, "contraction already prepared");
414  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
415  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
416  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
417  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
418  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
419  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
420  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
421  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
422  std::swap(rhs, lhs);
423  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
424  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
425  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
426  std::swap(rhs, lhs);
427  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
428  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
429  std::swap(lhs, rhs);
430  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
431  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
432  std::swap(lhs, rhs);
433  } else {
434  // TODO: llvm_unreachable ?
435  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
436  }
437  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
438  op, lhs, rhs, res,
439  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
440  op.getIteratorTypes());
441  return success();
442  }
443 };
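// Sketch of the net effect (shorthand maps, placeholder names): a contraction
// whose RHS is indexed as (n, k) has its RHS transposed so the op ends up in
// (m, k) x (k, n) -> (m, n) form.
//   vector.contract {maps = [(m, k), (n, k), (m, n)]} %a, %b, %c
// becomes
//   %bt = vector.transpose %b, [1, 0]
//   vector.contract {maps = [(m, k), (k, n), (m, n)]} %a, %bt, %c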
444 
445 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
446 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
447 // respectively. We can fold the transpose operation when loading the data from
448 // Shared Memory to registers.
449 struct CombineTransferReadOpTranspose final
450  : public OpRewritePattern<vector::TransposeOp> {
451  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
452 
453  LogicalResult matchAndRewrite(vector::TransposeOp op,
454  PatternRewriter &rewriter) const override {
455  // Look through integer extend ops.
456  Value source = op.getVector();
457  Type resultType = op.getType();
458  Operation *extOp;
459  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
460  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
461  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
462  source = extOp->getOperand(0);
463  resultType =
464  VectorType::get(cast<VectorType>(resultType).getShape(),
465  cast<VectorType>(source.getType()).getElementType());
466  }
467 
468  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
469  if (!transferReadOp)
470  return rewriter.notifyMatchFailure(op, "no transfer read");
471 
472  // TODO: support 0-d corner case.
473  if (transferReadOp.getTransferRank() == 0)
474  return rewriter.notifyMatchFailure(op, "0-D transfer read");
475 
476  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
477  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
478 
479  AffineMap permutationMap =
480  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
481  AffineMap newMap =
482  permutationMap.compose(transferReadOp.getPermutationMap());
483 
484  auto loc = op.getLoc();
485  Value result =
486  rewriter
487  .create<vector::TransferReadOp>(
488  loc, resultType, transferReadOp.getBase(),
489  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
490  transferReadOp.getPadding(), transferReadOp.getMask(),
491  transferReadOp.getInBoundsAttr())
492  .getResult();
493 
494  // Fuse through the integer extend op.
495  if (extOp) {
496  if (isa<arith::ExtSIOp>(extOp))
497  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
498  .getResult();
499  else if (isa<arith::ExtUIOp>(extOp))
500  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
501  .getResult();
502  else
503  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
504  .getResult();
505  }
506 
507  rewriter.replaceOp(op, result);
508  return success();
509  }
510 };
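// Rewrite sketch (placeholder names): the transpose's permutation is composed
// into the read's permutation map, e.g.
//   %v = vector.transfer_read %mem[%c0, %c0], %pad : memref<16x8xf16>, vector<16x8xf16>
//   %t = vector.transpose %v, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
// folds into a single vector.transfer_read of vector<8x16xf16> with
// permutation_map = affine_map<(d0, d1) -> (d1, d0)>.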
511 
512 } // namespace
513 
514 // MMA types have different layout based on how they are used in matmul ops.
515 // Figure the right layout to use by looking at op uses.
516 // TODO: Change the GPU dialect to abstract the layout at this level and
517 // only care about it during lowering to NVVM.
518 static const char *inferFragType(Operation *op) {
519  // We can have arith.ext ops before reaching contract ops. See through them
520  // and other kinds of elementwise ops.
521  if (op->hasOneUse()) {
522  Operation *userOp = *op->user_begin();
523  if (userOp->hasTrait<OpTrait::Elementwise>())
524  return inferFragType(userOp);
525  }
526 
527  for (Operation *users : op->getUsers()) {
528  auto contract = dyn_cast<vector::ContractionOp>(users);
529  if (!contract)
530  continue;
531  assert(op->getNumResults() == 1);
532  if (contract.getLhs() == op->getResult(0))
533  return "AOp";
534  if (contract.getRhs() == op->getResult(0))
535  return "BOp";
536  }
537  return "COp";
538 }
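// For example, a value feeding the LHS of a vector.contract is tagged "AOp",
// one feeding the RHS is "BOp", and everything else (including accumulators)
// defaults to "COp".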
539 
540 static LogicalResult
541 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
542  llvm::DenseMap<Value, Value> &valueMapping) {
543  OpBuilder::InsertionGuard g(rewriter);
544  rewriter.setInsertionPoint(op);
545 
546  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548  "expected convertible operation");
549 
550  std::optional<int64_t> stride =
551  getStaticallyKnownRowStride(op.getShapedType());
552  if (!stride.has_value()) {
553  LLVM_DEBUG(DBGS() << "no stride\n");
554  return rewriter.notifyMatchFailure(op, "no stride");
555  }
556 
557  AffineMap map = op.getPermutationMap();
558  bool isTranspose = isTransposeMatrixLoadMap(map);
559 
560  // Handle broadcast by setting the stride to 0.
561  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
562  assert(cstExpr.getValue() == 0);
563  stride = 0;
564  }
565 
566  Value mappingResult = op.getResult();
567  auto elType = op.getVectorType().getElementType();
568  const char *fragType = inferFragType(op);
569  if (op->hasOneUse()) {
570  auto *user = *op->user_begin();
571  // Infer the signedness of the mma type from the integer extend.
572  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
573  elType = IntegerType::get(
574  op.getContext(), cast<IntegerType>(elType).getWidth(),
575  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
576  : IntegerType::Unsigned);
577  mappingResult = user->getResult(0);
578  }
579  }
580  gpu::MMAMatrixType type =
581  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
582  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
583  op.getLoc(), type, op.getBase(), op.getIndices(),
584  rewriter.getIndexAttr(*stride),
585  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
586  valueMapping[mappingResult] = load;
587 
588  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
589  return success();
590 }
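// Minimal before/after sketch (placeholder names; leadDimension taken from the
// statically known row stride, here assumed to be 16):
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//          : memref<16x16xf16>, vector<16x16xf16>
// becomes
//   %m = gpu.subgroup_mma_load_matrix %mem[%c0, %c0] {leadDimension = 16 : index}
//          : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">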
591 
592 static LogicalResult
593 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
594  llvm::DenseMap<Value, Value> &valueMapping) {
595  OpBuilder::InsertionGuard g(rewriter);
596  rewriter.setInsertionPoint(op);
597 
598  assert(transferWriteSupportsMMAMatrixType(op));
599  std::optional<int64_t> stride =
600  getStaticallyKnownRowStride(op.getShapedType());
601  if (!stride.has_value()) {
602  LLVM_DEBUG(DBGS() << "no stride\n");
603  return rewriter.notifyMatchFailure(op, "no stride");
604  }
605 
606  auto it = valueMapping.find(op.getVector());
607  if (it == valueMapping.end()) {
608  LLVM_DEBUG(DBGS() << "no mapping\n");
609  return rewriter.notifyMatchFailure(op, "no mapping");
610  }
611 
612  Value matrix = it->second;
613  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
614  op.getLoc(), matrix, op.getBase(), op.getIndices(),
615  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
616  (void)store;
617 
618  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
619 
620  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
621  rewriter.eraseOp(op);
622  return success();
623 }
624 
625 /// Returns the vector type which represents a matrix fragment.
626 static VectorType
627 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
628  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
629  regInfo.elementsPerRegister};
630  Type elType = regInfo.registerLLVMType;
631  if (auto vecType = dyn_cast<VectorType>(elType))
632  elType = vecType.getElementType();
633  return VectorType::get(shape, elType);
634 }
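// For instance, with numRegistersPerFragment = 2, elementsPerRegister = 2 and
// registerLLVMType = vector<2xf16>, this returns vector<2x2xf16>.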
635 
636 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
637 static LogicalResult
638 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
639  llvm::DenseMap<Value, Value> &valueMapping) {
640  OpBuilder::InsertionGuard g(rewriter);
641  rewriter.setInsertionPoint(op);
642 
643  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
644  nvgpu::getWarpMatrixInfo(op);
645  if (failed(warpMatrixInfo)) {
646  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
647  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
648  }
649 
650  FailureOr<nvgpu::FragmentElementInfo> regInfo =
651  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
652  if (failed(regInfo)) {
653  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
654  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
655  }
656 
657  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
658  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
659  if (!dense) {
660  LLVM_DEBUG(DBGS() << "not a splat\n");
661  return rewriter.notifyMatchFailure(op, "not a splat");
662  }
663 
664  Value result = rewriter.create<arith::ConstantOp>(
665  op.getLoc(), vectorType,
666  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
667  valueMapping[op.getResult()] = result;
668  return success();
669 }
670 
671 /// Check if the loaded matrix operand requires a transpose.
672 /// Transposed Map Example:
673 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
674 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
675 /// The code below checks if the output 2D is transposed using a generalized
676 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
677 /// Returns true if m > n, false otherwise.
678 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
679  mlir::AffineMap map = op.getPermutationMap();
680 
681  if (map.getNumResults() != 2) {
682  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
683  "is not a 2d operand\n");
684  return failure();
685  }
686 
687  // Output 2D matrix dimensions in the order of d0, d1.
688  mlir::AffineExpr dM = map.getResult(0);
689  mlir::AffineExpr dN = map.getResult(1);
690 
691  // Find the position of these expressions in the input.
692  auto exprM = dyn_cast<AffineDimExpr>(dM);
693  auto exprN = dyn_cast<AffineDimExpr>(dN);
694 
695  if (!exprM || !exprN) {
696  LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
697  "expressions, then transpose cannot be determined.\n");
698  return failure();
699  }
700 
701  return exprM.getPosition() > exprN.getPosition();
702 }
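// E.g. affine_map<(d0, d1) -> (d1, d0)> yields true and
// affine_map<(d0, d1) -> (d0, d1)> yields false, while a map whose results are
// not plain dims, such as (d0, d1) -> (d0, 0), is rejected with failure().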
703 
704 static LogicalResult
705 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
706  llvm::DenseMap<Value, Value> &valueMapping) {
707  OpBuilder::InsertionGuard g(rewriter);
708  rewriter.setInsertionPoint(op);
709  Location loc = op->getLoc();
710 
711  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
712  nvgpu::getWarpMatrixInfo(op);
713  if (failed(warpMatrixInfo)) {
714  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
715  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
716  }
717 
718  FailureOr<nvgpu::FragmentElementInfo> regInfo =
719  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
720  if (failed(regInfo)) {
721  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
722  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
723  }
724 
725  FailureOr<bool> transpose = isTransposed(op);
726  if (failed(transpose)) {
727  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
728  return rewriter.notifyMatchFailure(
729  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
730  }
731 
732  FailureOr<nvgpu::LdMatrixParams> params =
733  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
734 
735  if (failed(params)) {
736  LLVM_DEBUG(
737  DBGS()
738  << "failed to convert vector.transfer_read to ldmatrix. "
739  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
740  return rewriter.notifyMatchFailure(
741  op, "failed to convert vector.transfer_read to ldmatrix; this op "
742  "likely should not be converted to a nvgpu.ldmatrix call.");
743  }
744 
745  // Adjust the load offset.
746  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
747  FailureOr<AffineMap> offsets =
748  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
749  if (failed(offsets)) {
750  LLVM_DEBUG(DBGS() << "no offsets\n");
751  return rewriter.notifyMatchFailure(op, "no offsets");
752  }
753 
754  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
755 
756  SmallVector<Value, 4> indices;
757  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
758  indices);
759 
760  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
761  loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
762  valueMapping[op] = newOp->getResult(0);
763  return success();
764 }
765 
766 static LogicalResult
767 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
768  llvm::DenseMap<Value, Value> &valueMapping) {
769  OpBuilder::InsertionGuard g(rewriter);
770  rewriter.setInsertionPoint(op);
771 
772  Location loc = op.getLoc();
773  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
774  nvgpu::getWarpMatrixInfo(op);
775  if (failed(warpMatrixInfo))
776  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
777  FailureOr<nvgpu::FragmentElementInfo> regInfo =
778  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
779  if (failed(regInfo)) {
780  return rewriter.notifyMatchFailure(
781  op, "Failed to deduce register fragment type during "
782  "conversion to distributed non-ldmatrix compatible load");
783  }
784 
785  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
786 
787  // This is the individual element type.
788  Type loadedElType = regInfo->registerLLVMType;
789  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
790 
791  Value fill = rewriter.create<arith::ConstantOp>(
792  op.getLoc(), vectorType.getElementType(),
793  rewriter.getZeroAttr(vectorType.getElementType()));
794  Value result =
795  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
796 
797  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798 
799  // If we are not transposing, then we can use vectorized loads. Otherwise, we
800  // must load each element individually.
801  if (!isTransposeLoad) {
802  if (!isa<VectorType>(loadedElType)) {
803  loadedElType = VectorType::get({1}, loadedElType);
804  }
805 
806  for (int i = 0; i < vectorType.getShape()[0]; i++) {
807  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
808  rewriter, op.getLoc(), *warpMatrixInfo);
809  if (failed(coords))
810  return rewriter.notifyMatchFailure(op, "no coords");
811 
812  Value logicalValueId = rewriter.create<arith::ConstantOp>(
813  loc, rewriter.getIndexType(),
814  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
815  SmallVector<Value, 4> newIndices;
816  getXferIndices<vector::TransferReadOp>(
817  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
818 
819  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
820  op.getBase(), newIndices);
821  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
822  }
823  } else {
824  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
825  loadedElType = vecType.getElementType();
826  }
827  for (int i = 0; i < vectorType.getShape()[0]; i++) {
828  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
829  innerIdx++) {
830 
831  Value logicalValueId = rewriter.create<arith::ConstantOp>(
832  loc, rewriter.getIndexType(),
833  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
834  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
835  rewriter, op.getLoc(), *warpMatrixInfo);
836  if (failed(coords))
837  return rewriter.notifyMatchFailure(op, "no coords");
838 
839  SmallVector<Value, 4> newIndices;
840  getXferIndices<vector::TransferReadOp>(
841  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
842  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
843  op.getBase(), newIndices);
844  result = rewriter.create<vector::InsertOp>(
845  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
846  }
847  }
848  }
849 
850  valueMapping[op.getResult()] = result;
851  return success();
852 }
853 
854 /// Return true if this is a shared memory memref type.
855 static bool isSharedMemory(MemRefType type) {
856  auto addressSpace =
857  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
858  return addressSpace &&
859  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
860 }
861 
862 /// Converts a `vector.transfer_read` operation directly to either a
863 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
864 /// used when converting to `nvgpu.mma.sync` operations.
865 static LogicalResult
866 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
867  llvm::DenseMap<Value, Value> &valueMapping) {
868  OpBuilder::InsertionGuard g(rewriter);
869  rewriter.setInsertionPoint(op);
870 
871  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
872  nvgpu::getWarpMatrixInfo(op);
873  if (failed(warpMatrixInfo))
874  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
875 
876  bool isLdMatrixCompatible =
877  isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
878  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
879 
880  VectorType vecTy = op.getVectorType();
881  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
882 
883  // When we are transposing the B operand, ldmatrix will only work if we have
884  // at least 8 rows to read and the width to read for the transpose is 128
885  // bits.
886  if (!op.getPermutationMap().isMinorIdentity() &&
887  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
888  vecTy.getDimSize(0) * bitWidth < 128))
889  isLdMatrixCompatible = false;
890 
891  if (!isLdMatrixCompatible)
892  return createNonLdMatrixLoads(rewriter, op, valueMapping);
893 
894  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
895 }
896 
897 static LogicalResult
898 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
899  llvm::DenseMap<Value, Value> &valueMapping) {
900  OpBuilder::InsertionGuard g(rewriter);
901  rewriter.setInsertionPoint(op);
902 
903  Location loc = op->getLoc();
904  auto it = valueMapping.find(op.getVector());
905  if (it == valueMapping.end())
906  return rewriter.notifyMatchFailure(op, "no mapping");
907  Value matrix = it->second;
908 
909  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
910  nvgpu::getWarpMatrixInfo(op);
911  if (failed(warpMatrixInfo))
912  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
913  FailureOr<nvgpu::FragmentElementInfo> regInfo =
914  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
915  if (failed(regInfo))
916  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
917 
918  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
919  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
920 
921  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
922  Value logicalValueId = rewriter.create<arith::ConstantOp>(
923  loc, rewriter.getIndexType(),
924  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
925  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
926  rewriter, op.getLoc(), *warpMatrixInfo);
927  if (failed(coords))
928  return rewriter.notifyMatchFailure(op, "no coords");
929 
930  Value el =
931  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
932  SmallVector<Value, 4> newIndices;
933  getXferIndices<vector::TransferWriteOp>(
934  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
935  rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
936  }
937 
938  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
939  rewriter.eraseOp(op);
940  return success();
941 }
942 
943 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944  SmallVectorImpl<int64_t> &results) {
945  for (auto attr : arrayAttr)
946  results.push_back(cast<IntegerAttr>(attr).getInt());
947 }
948 
949 static LogicalResult
950 convertExtractStridedSlice(RewriterBase &rewriter,
951  vector::ExtractStridedSliceOp op,
952  llvm::DenseMap<Value, Value> &valueMapping) {
953  OpBuilder::InsertionGuard g(rewriter);
954  rewriter.setInsertionPoint(op);
955 
956  Location loc = op->getLoc();
957 
958  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
959  nvgpu::getWarpMatrixInfo(op);
960  if (failed(warpMatrixInfo))
961  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
962 
963  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
964  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
965  if (failed(mmaSyncFragmentInfo))
966  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
967 
968  // Find the vector.transfer_read whose result vector is being sliced.
969  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
970  if (!transferReadOp)
971  return rewriter.notifyMatchFailure(op, "no transfer read");
972 
973  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
974  if (failed(warpMatrixInfo))
975  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
976 
977  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
978  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
979  if (failed(ldFragmentInfo))
980  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
981 
982  assert(
983  (mmaSyncFragmentInfo->elementsPerRegister ==
984  ldFragmentInfo->elementsPerRegister) &&
985  "Number of elements per register should be same for load and mma.sync");
986 
987  // Create vector.extract_strided_slice op for thread-owned fragments.
988  std::array<int64_t, 2> strides = {1,
989  1}; // stride for extract slice is always 1.
990  std::array<int64_t, 2> sliceShape = {
991  mmaSyncFragmentInfo->numRegistersPerFragment,
992  mmaSyncFragmentInfo->elementsPerRegister};
993  auto it = valueMapping.find(transferReadOp);
994  if (it == valueMapping.end())
995  return rewriter.notifyMatchFailure(op, "no mapping");
996  auto sourceVector = it->second;
997 
998  // Offsets and sizes at the warp level of ownership.
999  SmallVector<int64_t> offsets;
1000  populateFromInt64AttrArray(op.getOffsets(), offsets);
1001 
1002  SmallVector<int64_t> sizes;
1003  populateFromInt64AttrArray(op.getSizes(), sizes);
1004  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005 
1006  // Compute offset in vector registers. Note that the mma.sync vector registers
1007  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1008  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1009  std::array<int64_t, 2> sliceOffset = {0, 0};
1010 
1011  if (offsets[0] && offsets[1])
1012  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1013  if (offsets[0])
1014  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1015  else if (offsets[1])
1016  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1017 
1018  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1019  loc, sourceVector, sliceOffset, sliceShape, strides);
1020 
1021  valueMapping[op] = newOp;
1022  return success();
1023 }
1024 
1025 static LogicalResult
1026 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1027  llvm::DenseMap<Value, Value> &valueMapping) {
1028  OpBuilder::InsertionGuard g(rewriter);
1029  rewriter.setInsertionPoint(op);
1030 
1031  auto itA = valueMapping.find(op.getLhs());
1032  auto itB = valueMapping.find(op.getRhs());
1033  auto itC = valueMapping.find(op.getAcc());
1034  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1035  itC == valueMapping.end())
1036  return rewriter.notifyMatchFailure(op, "no mapping");
1037  Value opA = itA->second, opB = itB->second, opC = itC->second;
1038  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1039  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1040  /*b_transpose=*/UnitAttr());
1041  valueMapping[op.getResult()] = matmul;
1042  return success();
1043 }
1044 
1045 static LogicalResult
1046 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047  llvm::DenseMap<Value, Value> &valueMapping) {
1048  OpBuilder::InsertionGuard g(rewriter);
1049  rewriter.setInsertionPoint(op);
1050 
1051  auto itA = valueMapping.find(op.getLhs());
1052  auto itB = valueMapping.find(op.getRhs());
1053  auto itC = valueMapping.find(op.getAcc());
1054  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055  itC == valueMapping.end())
1056  return rewriter.notifyMatchFailure(op, "no mapping");
1057  Value opA = itA->second, opB = itB->second, opC = itC->second;
1058  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1062  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1063  valueMapping[op.getResult()] = matmul;
1064  return success();
1065 }
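// For example, with %lhs : vector<16x16xf16>, %rhs : vector<8x16xf16> and
// %acc : vector<16x8xf16>, the op above is created with mmaShape = [16, 8, 16].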
1066 
1067 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1068 static LogicalResult
1069 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070  llvm::DenseMap<Value, Value> &valueMapping) {
1071  OpBuilder::InsertionGuard g(rewriter);
1072  rewriter.setInsertionPoint(op);
1073 
1074  assert(constantSupportsMMAMatrixType(op));
1075 
1076  auto splat =
1077  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078  auto scalarConstant =
1079  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1080  const char *fragType = inferFragType(op);
1081  auto vecType = cast<VectorType>(op.getType());
1082  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1083  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1085  op.getLoc(), type, scalarConstant);
1086  valueMapping[op.getResult()] = matrix;
1087  return success();
1088 }
1089 
1090 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1091 static LogicalResult
1092 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093  llvm::DenseMap<Value, Value> &valueMapping) {
1094  OpBuilder::InsertionGuard g(rewriter);
1095  rewriter.setInsertionPoint(op);
1096 
1097  assert(broadcastSupportsMMAMatrixType(op));
1098 
1099  const char *fragType = inferFragType(op);
1100  auto vecType = op.getResultVectorType();
1101  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1102  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1104  op.getLoc(), type, op.getSource());
1105  valueMapping[op.getResult()] = matrix;
1106  return success();
1107 }
1108 
1109 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110 // updated and needs to be updated separately for the loop to be correct.
1111 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112  scf::ForOp loop,
1113  ValueRange newInitArgs) {
1114  OpBuilder::InsertionGuard g(rewriter);
1115  rewriter.setInsertionPoint(loop);
1116 
1117  // Create a new loop before the existing one, with the extra operands.
1118  rewriter.setInsertionPoint(loop);
1119  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120  llvm::append_range(operands, newInitArgs);
1121  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1122  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1123  operands);
1124  rewriter.eraseBlock(newLoop.getBody());
1125 
1126  newLoop.getRegion().getBlocks().splice(
1127  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128  for (Value operand : newInitArgs)
1129  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1130 
1131  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132  loop.getNumResults())))
1133  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1134 
1135  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1136  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1137  LLVM_DEBUG(DBGS() << "erase: " << loop);
1138 
1139  rewriter.eraseOp(loop);
1140  return newLoop;
1141 }
1142 
1143 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144  llvm::DenseMap<Value, Value> &valueMapping) {
1145  OpBuilder::InsertionGuard g(rewriter);
1146  rewriter.setInsertionPoint(op);
1147 
1148  SmallVector<Value> newOperands;
1149  SmallVector<std::pair<size_t, size_t>> argMapping;
1150  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1151  auto it = valueMapping.find(operand.value());
1152  if (it == valueMapping.end()) {
1153  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1154  continue;
1155  }
1156  argMapping.push_back(std::make_pair(
1157  operand.index(), op.getInitArgs().size() + newOperands.size()));
1158  newOperands.push_back(it->second);
1159  }
1160 
1161  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1162  Block &loopBody = *newForOp.getBody();
1163  for (auto mapping : argMapping) {
1164  valueMapping[newForOp.getResult(mapping.first)] =
1165  newForOp.getResult(mapping.second);
1166  valueMapping[loopBody.getArgument(mapping.first +
1167  newForOp.getNumInductionVars())] =
1168  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1169  }
1170 
1171  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1172  return success();
1173 }
1174 
1175 static LogicalResult
1176 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177  llvm::DenseMap<Value, Value> &valueMapping) {
1178  OpBuilder::InsertionGuard g(rewriter);
1179  rewriter.setInsertionPoint(op);
1180 
1181  auto loop = cast<scf::ForOp>(op->getParentOp());
1182  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1183  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1184  auto it = valueMapping.find(operand.value());
1185  if (it == valueMapping.end())
1186  continue;
1187  // Replace the yield of the old value with the for op argument to make it
1188  // easier to remove the dead code.
1189  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190  yieldOperands.push_back(it->second);
1191  }
1192  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1193 
1194  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1195  rewriter.eraseOp(op);
1196  return success();
1197 }
1198 
1199 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1200 static LogicalResult
1201 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1202  gpu::MMAElementwiseOp opType,
1203  llvm::DenseMap<Value, Value> &valueMapping) {
1204  OpBuilder::InsertionGuard g(rewriter);
1205  rewriter.setInsertionPoint(op);
1206 
1207  SmallVector<Value> matrixOperands;
1208  for (Value operand : op->getOperands()) {
1209  auto it = valueMapping.find(operand);
1210  if (it == valueMapping.end())
1211  return rewriter.notifyMatchFailure(op, "no mapping");
1212  matrixOperands.push_back(it->second);
1213  }
1214  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215  if (opType == gpu::MMAElementwiseOp::EXTF) {
1216  // The floating point extension case has a different result type.
1217  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1218  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1219  vectorType.getElementType(),
1220  resultType.getOperand());
1221  }
1222 
1223  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1224  op->getLoc(), resultType, matrixOperands, opType);
1225  valueMapping[op->getResult(0)] = newOp;
1226  return success();
1227 }
1228 
1228 
1229 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1230  bool useNvGpu) {
1231  if (!useNvGpu) {
1232  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233  patterns.getContext());
1234  return;
1235  }
1236  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1237  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1238 }
1239 
1240 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1241  Operation *rootOp) {
1242  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1243  llvm::DenseMap<Value, Value> valueMapping;
1244 
1245  auto globalRes = LogicalResult::success();
1246  for (Operation *op : ops) {
1247  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1248  // Apparently callers do not want to early exit on failure here.
1249  auto res = LogicalResult::success();
1250  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1251  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1252  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1253  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1254  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1255  res = convertContractOp(rewriter, contractOp, valueMapping);
1256  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1257  res = convertConstantOp(rewriter, constantOp, valueMapping);
1258  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1259  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1260  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1261  res = convertForOp(rewriter, forOp, valueMapping);
1262  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1263  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1264  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1266  }
1267  if (failed(res))
1268  globalRes = failure();
1269  }
1270  return globalRes;
1271 }
1272 
1273 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1274  Operation *rootOp) {
1275  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1276  llvm::DenseMap<Value, Value> valueMapping;
1277  for (Operation *op : ops) {
1278  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1279  .Case([&](vector::TransferReadOp transferReadOp) {
1280  return convertTransferReadToLoads(rewriter, transferReadOp,
1281  valueMapping);
1282  })
1283  .Case([&](vector::TransferWriteOp transferWriteOp) {
1284  return convertTransferWriteToStores(rewriter, transferWriteOp,
1285  valueMapping);
1286  })
1287  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1289  valueMapping);
1290  })
1291  .Case([&](vector::ContractionOp contractionOp) {
1292  return convertContractOpToMmaSync(rewriter, contractionOp,
1293  valueMapping);
1294  })
1295  .Case([&](scf::ForOp forOp) {
1296  return convertForOp(rewriter, forOp, valueMapping);
1297  })
1298  .Case([&](scf::YieldOp yieldOp) {
1299  return convertYieldOp(rewriter, yieldOp, valueMapping);
1300  })
1301  .Case([&](arith::ConstantOp constOp) {
1302  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1303  })
1304  .Default([&](Operation *op) {
1305  return op->emitError() << "unhandled vector to mma type: " << *op;
1306  })
1307  .failed()) {
1308  return op->emitOpError()
1309  << "failed to convert op during vector-to-nvgpu conversion";
1310  }
1311  }
1312  return success();
1313 }
1314 
1315 namespace {
1316 
1317 struct ConvertVectorToGPUPass
1318  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319 
1320  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321  useNvGpu.setValue(useNvGpu_);
1322  }
1323 
1324  void runOnOperation() override {
1325  RewritePatternSet patterns(&getContext());
1326  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1327  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1328  return signalPassFailure();
1329 
1330  IRRewriter rewriter(&getContext());
1331  if (useNvGpu) {
1332  if (failed(
1333  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1334  return signalPassFailure();
1335  return;
1336  }
1337  (void)convertVectorToMMAOps(rewriter, getOperation());
1338  }
1339 };
1340 
1341 } // namespace
1342 
1343 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1344  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1345 }
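// A typical way to exercise this conversion is through mlir-opt (assuming the
// registered pass name and option are spelled as below):
//   mlir-opt --convert-vector-to-gpu input.mlir
//   mlir-opt --convert-vector-to-gpu="use-nvgpu=true" input.mlir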
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:73
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:36
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:53
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
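As a rough illustration of that check (a sketch with hypothetical names, not the helper above): a broadcast qualifies when its source is a scalar and its result is a rank-2 vector.

// Sketch only: scalar source (non-vector type) broadcast into a 2-D vector.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

static bool isScalarTo2DBroadcast(mlir::vector::BroadcastOp op) {
  // A vector-typed source is not a scalar broadcast.
  if (mlir::isa<mlir::VectorType>(op.getSource().getType()))
    return false;
  return mlir::cast<mlir::VectorType>(op.getType()).getRank() == 2;
}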
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:151
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dimension referenced in exprs, and no symbols.
Definition: AffineMap.cpp:308
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
UnitAttr getUnitAttr()
Definition: Builders.cpp:93
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:367
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:313
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
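For instance, the 2-D splat constants that this conversion operates on can be built through this entry point; a single-element value list yields a splat attribute. The snippet below is a sketch with arbitrary shape and element type, not code from this file.

// Sketch: a 16x16 f16 splat of zero; shape and element type are illustrative.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::DenseElementsAttr makeSplat2D(mlir::OpBuilder &b) {
  auto vecType = mlir::VectorType::get({16, 16}, b.getF16Type());
  mlir::Attribute zero = b.getZeroAttr(b.getF16Type());
  // A single-element value list produces a splat attribute for the whole type.
  return mlir::DenseElementsAttr::get(vecType, {zero});
}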
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
user_iterator user_begin()
Definition: Operation.h:869
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:125
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1280
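A small usage sketch (the map and names are illustrative, not from this file): apply d0 + 4 to an index value; the helper composes through any affine.apply ops already feeding the operand before materializing the new one.

// Sketch: offset an index by a constant via a composed affine.apply.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

static mlir::Value addFourToIndex(mlir::OpBuilder &b, mlir::Location loc,
                                  mlir::Value index) {
  mlir::AffineExpr d0 = b.getAffineDimExpr(0);
  mlir::AffineMap map =
      mlir::AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 4);
  llvm::SmallVector<mlir::OpFoldResult, 1> operands = {index};
  return mlir::affine::makeComposedAffineApply(b, loc, map, operands);
}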
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns two results representing offsets within a matrix operand.
Definition: MMAUtils.cpp:169
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:234
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed load, return the LdMatrixParams.
Definition: MMAUtils.cpp:205
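These nvgpu helpers are typically chained: classify the warp-level operand, derive the ldmatrix parameters, then build the lane-to-coordinate map. A hedged sketch of that chain follows (the wrapper name is illustrative).

// Sketch: lane -> matrix-coordinate map for an ldmatrix-style load of the
// warp-matrix operand produced or consumed by `op`.
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/IR/Builders.h"

static mlir::FailureOr<mlir::AffineMap>
getLdMatrixLaneMap(mlir::OpBuilder &builder, mlir::Operation *op,
                   bool transpose) {
  mlir::FailureOr<mlir::nvgpu::WarpMatrixInfo> warpInfo =
      mlir::nvgpu::getWarpMatrixInfo(op);
  if (mlir::failed(warpInfo))
    return mlir::failure();
  mlir::FailureOr<mlir::nvgpu::LdMatrixParams> params =
      mlir::nvgpu::getLdMatrixParams(*warpInfo, transpose);
  if (mlir::failed(params))
    return mlir::failure();
  return mlir::nvgpu::getLaneIdToLdMatrixMatrixCoord(builder, op->getLoc(),
                                                     *params);
}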
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.
Definition: MMAUtils.cpp:272
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:153
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:148
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
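A hedged sketch of how these entry points are usually combined (error handling simplified; for the nvgpu path one would pass useNvGpu=true and call convertVectorToNVVMCompatibleMMASync instead of convertVectorToMMAOps):

// Sketch: prepare vector ops, then convert them to gpu.subgroup_mma_* ops.
// rootOp is expected to be isolated from above (e.g. a func.func).
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerVectorToMMA(mlir::Operation *rootOp) {
  mlir::MLIRContext *ctx = rootOp->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (mlir::failed(mlir::applyPatternsGreedily(rootOp, std::move(patterns))))
    return mlir::failure();
  mlir::IRRewriter rewriter(ctx);
  return mlir::convertVectorToMMAOps(rewriter, rootOp);
}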
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. the transitive defs of op).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:311
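For example (a minimal sketch, unrelated to any particular map in this file), binding two dimensions and building the swap map (d0, d1) -> (d1, d0) from the resulting expressions:

// Sketch: bindDims fills d0 and d1 with dimension expressions 0 and 1.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

static mlir::AffineMap buildSwapMap(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1;
  mlir::bindDims(ctx, d0, d1);
  return mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);
}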
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm.mma.sync lowering path.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
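For instance, adding the conversion to a pass pipeline (a sketch; how the pass is nested depends on the surrounding pipeline):

// Sketch: run the vector-to-GPU conversion with the nvgpu/mma.sync path enabled.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Pass/PassManager.h"

static void addVectorToGPU(mlir::PassManager &pm) {
  pm.addPass(mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
}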
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the detail classes directly.
Definition: AffineExpr.cpp:621
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. the transitive uses of op).
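A hedged sketch of how these slice utilities compose (the default options used here are an assumption; the conversion's own slice gathering differs in detail): collect the transitive defs and uses of a root op into one set and order it topologically.

// Sketch: backward slice + forward slice of `root`, topologically sorted.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "llvm/ADT/SetVector.h"

static llvm::SetVector<mlir::Operation *>
collectSortedSlice(mlir::Operation *root) {
  llvm::SetVector<mlir::Operation *> slice;
  slice.insert(root);
  // Default options: no filtering of the traversal.
  mlir::BackwardSliceOptions backwardOptions;
  if (mlir::failed(mlir::getBackwardSlice(root, &slice, backwardOptions)))
    return slice;
  mlir::getForwardSlice(root, &slice);
  return mlir::topologicalSort(slice);
}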
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
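The rewrite hooks listed above are normally exercised from a pattern like the following sketch (a hypothetical folding pattern, not one registered by this conversion): it reports non-matches through notifyMatchFailure and forwards the source value with replaceOp.

// Sketch: drop a vector.broadcast whose source already has the result type.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct FoldIdentityBroadcast
    : public mlir::OpRewritePattern<mlir::vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::BroadcastOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (op.getSource().getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "not an identity broadcast");
    rewriter.replaceOp(op, op.getSource());
    return mlir::success();
  }
};
} // namespace

Such a pattern would be added to a RewritePatternSet and run with applyPatternsGreedily, as in the driver sketch earlier on this page.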
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:29
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
Definition: MMAUtils.h:52