MLIR  19.0.0git
VectorToGPU.cpp
Go to the documentation of this file.
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include <type_traits>
16 
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
75 
76 // Return true if the contract op can be convert to MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [&](MapList m) {
81  return AffineMap::inferFromExprList(m, contract.getContext());
82  };
83  AffineExpr m, n, k;
84  bindDims(contract.getContext(), m, n, k);
85  auto iteratorTypes = contract.getIteratorTypes().getValue();
86  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87  vector::isParallelIterator(iteratorTypes[1]) &&
88  vector::isReductionIterator(iteratorTypes[2])))
89  return false;
90 
91  // The contract needs to represent a matmul to be able to convert to
92  // MMAMatrix matmul.
93  if (!useNvGpu &&
94  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95  return false;
96  if (useNvGpu &&
97  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98  return false;
99 
100  return true;
101 }
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106  MLIRContext *ctx = permutationMap.getContext();
107  // Local OpBuilder is fine here, we just build attributes.
108  OpBuilder b(ctx);
109  auto nDim = permutationMap.getNumDims();
110  AffineExpr zero = b.getAffineConstantExpr(0);
111  if (nDim < 2) {
112  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113  AffineExpr dim0 = b.getAffineDimExpr(0);
114  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115  }
116 
117  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119  // Support both transposed and transposed+broadcasted cases.
120  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
123 
124 // Return the stide for the second-to-last dimension of |type| if it is a memref
125 // and has a constant stride.
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127  auto memrefType = dyn_cast<MemRefType>(type);
128  if (!memrefType)
129  return false;
130  // If the memref is 0 or 1D the horizontal stride is 0.
131  if (memrefType.getRank() < 2)
132  return 0;
133  int64_t offset = 0;
134  SmallVector<int64_t, 2> strides;
135  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
136  strides.back() != 1)
137  return std::nullopt;
138  int64_t stride = strides[strides.size() - 2];
139  if (stride == ShapedType::kDynamic)
140  return std::nullopt;
141  return stride;
142 }
143 
144 // Return true if the transfer op can be converted to a MMA matrix load.
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147  readOp.getVectorType().getRank() != 2)
148  return false;
149  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150  return false;
151 
152  // Only allow integer types if the signedness can be inferred.
153  if (readOp.getVectorType().getElementType().isInteger(8))
154  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156  return false;
157 
158  AffineMap map = readOp.getPermutationMap();
159  MLIRContext *ctx = readOp.getContext();
160  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161  AffineExpr zero = getAffineConstantExpr(0, ctx);
162  auto broadcastInnerDim =
163  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164  return map.isMinorIdentity() || map == broadcastInnerDim ||
166 }
167 
168 // Return true if the transfer op can be converted to a MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171  // TODO: support 0-d corner case.
172  if (writeOp.getTransferRank() == 0)
173  return false;
174 
175  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176  writeOp.getVectorType().getRank() != 2)
177  return false;
178  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179  return false;
180  // TODO: Support transpose once it is added to GPU dialect ops.
181  if (!writeOp.getPermutationMap().isMinorIdentity())
182  return false;
183  return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to a MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189  auto vecType = dyn_cast<VectorType>(constantOp.getType());
190  if (!vecType || vecType.getRank() != 2)
191  return false;
192  return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197  return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
204  return false;
205  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
214  if (isa<arith::AddFOp>(op))
215  return gpu::MMAElementwiseOp::ADDF;
216  if (isa<arith::MulFOp>(op))
217  return gpu::MMAElementwiseOp::MULF;
218  if (isa<arith::SubFOp>(op))
219  return gpu::MMAElementwiseOp::SUBF;
220  if (isa<arith::MaximumFOp>(op))
221  return gpu::MMAElementwiseOp::MAXF;
222  if (isa<arith::MinimumFOp>(op))
223  return gpu::MMAElementwiseOp::MINF;
224  if (isa<arith::DivFOp>(op))
225  return gpu::MMAElementwiseOp::DIVF;
226  if (isa<arith::AddIOp>(op))
228  if (isa<arith::MulIOp>(op))
230  if (isa<arith::SubIOp>(op))
232  if (isa<arith::DivSIOp>(op))
233  return gpu::MMAElementwiseOp::DIVS;
234  if (isa<arith::DivUIOp>(op))
235  return gpu::MMAElementwiseOp::DIVU;
236  if (isa<arith::NegFOp>(op))
237  return gpu::MMAElementwiseOp::NEGATEF;
238  if (isa<arith::ExtFOp>(op))
239  return gpu::MMAElementwiseOp::EXTF;
240  return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
245  return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported with `mma.sync`
249 /// path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
255  if (failed(warpMatrixInfo))
256  return false;
257 
258  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259  if (failed(contractOp))
260  return false;
261 
262  // Handle vector.extract_strided_slice on registers containing
263  // matrixB and matrixC operands. vector.extract_strided_slice op
264  // is not supported on registers containing matrixA operands.
265  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266  return (cast<VectorType>(op->getResult(0).getType()) ==
267  cast<VectorType>((*contractOp).getRhs().getType()));
268  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269  return (cast<VectorType>(op->getResult(0).getType()) ==
270  cast<VectorType>((*contractOp).getAcc().getType()));
271 
272  return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276  if (isa<scf::ForOp, scf::YieldOp>(op))
277  return true;
278  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280  : transferReadSupportsMMAMatrixType(transferRead);
281  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283  : transferWriteSupportsMMAMatrixType(transferWrite);
284  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285  return useNvGpu &&
286  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287  if (auto contract = dyn_cast<vector::ContractionOp>(op))
288  return contractSupportsMMAMatrixType(contract, useNvGpu);
289  if (auto constant = dyn_cast<arith::ConstantOp>(op))
290  return constantSupportsMMAMatrixType(constant);
291  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
293  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298  return fpExtendSupportsMMAMatrixType(fpExtend);
300 }
301 
302 /// Return an unsorted slice handling scf.for region differently than
303 /// `getSlice`. In scf.for we only want to include as part of the slice elements
304 /// that are part of the use/def chain.
307  const BackwardSliceOptions &backwardSliceOptions,
308  const ForwardSliceOptions &forwardSliceOptions) {
310  slice.insert(op);
311  unsigned currentIndex = 0;
312  SetVector<Operation *> backwardSlice;
313  SetVector<Operation *> forwardSlice;
314  while (currentIndex != slice.size()) {
315  auto *currentOp = (slice)[currentIndex];
316  // Compute and insert the backwardSlice starting from currentOp.
317  backwardSlice.clear();
318  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
319  slice.insert(backwardSlice.begin(), backwardSlice.end());
320 
321  // Compute and insert the forwardSlice starting from currentOp.
322  forwardSlice.clear();
323  // Special case for ForOp, we don't want to include the whole region but
324  // only the value using the region arguments.
325  // TODO: We should refine this to only care about the region arguments being
326  // converted to matrix type.
327  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328  for (Value forOpResult : forOp.getResults())
329  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330  for (BlockArgument &arg : forOp.getRegionIterArgs())
331  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332  } else {
333  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334  }
335  slice.insert(forwardSlice.begin(), forwardSlice.end());
336  ++currentIndex;
337  }
338  return slice;
339 }
340 
341 // Analyze slice of operations based on convert op to figure out if the whole
342 // slice can be converted to MMA operations.
344  bool useNvGpu) {
345  auto hasVectorDest = [](Operation *op) {
346  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347  };
348  BackwardSliceOptions backwardSliceOptions;
349  backwardSliceOptions.filter = hasVectorDest;
350 
351  auto hasVectorSrc = [](Operation *op) {
352  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353  };
354  ForwardSliceOptions forwardSliceOptions;
355  forwardSliceOptions.filter = hasVectorSrc;
356 
357  SetVector<Operation *> opToConvert;
358  op->walk([&](vector::ContractionOp contract) {
359  if (opToConvert.contains(contract.getOperation()))
360  return;
361  SetVector<Operation *> dependentOps =
362  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
363  // If any instruction cannot use MMA matrix type drop the whole
364  // chain. MMA matrix are stored in an opaque type so they cannot be used
365  // by all operations.
366  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
367  if (!supportsMMaMatrixType(op, useNvGpu)) {
368  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
369  return true;
370  }
371  return false;
372  }))
373  return;
374 
375  opToConvert.insert(dependentOps.begin(), dependentOps.end());
376  });
377  // Sort the operations so that we can convert them in topological order.
378  return topologicalSort(opToConvert);
379 }
380 
381 namespace {
382 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
383 // to MMA matmul.
384 struct PrepareContractToGPUMMA
385  : public OpRewritePattern<vector::ContractionOp> {
387 
388  LogicalResult matchAndRewrite(vector::ContractionOp op,
389  PatternRewriter &rewriter) const override {
390  Location loc = op.getLoc();
391  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
392 
393  // Set up the parallel/reduction structure in right form.
394  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
395  auto infer = [&](MapList m) {
397  };
398  AffineExpr m, n, k;
399  bindDims(rewriter.getContext(), m, n, k);
400  static constexpr std::array<int64_t, 2> perm = {1, 0};
401  auto iteratorTypes = op.getIteratorTypes().getValue();
402  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
403  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
404  vector::isParallelIterator(iteratorTypes[1]) &&
405  vector::isReductionIterator(iteratorTypes[2])))
406  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407  //
408  // Two outer parallel, one inner reduction (matmat flavor).
409  //
410  // This is the classical row-major matmul, nothing to do.
411  if (maps == infer({{m, k}, {k, n}, {m, n}}))
412  return rewriter.notifyMatchFailure(op, "contraction already prepared");
413  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
414  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
415  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
416  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
417  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
418  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
419  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
421  std::swap(rhs, lhs);
422  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
423  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
425  std::swap(rhs, lhs);
426  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
427  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
428  std::swap(lhs, rhs);
429  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
430  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
431  std::swap(lhs, rhs);
432  } else {
433  // TODO: llvm_unreachable ?
434  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
435  }
436  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
437  op, lhs, rhs, res,
438  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439  op.getIteratorTypes());
440  return success();
441  }
442 };
443 
444 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
445 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
446 // respectively. We can fold the transpose operation when loading the data from
447 // Shared Memory to registers.
448 struct CombineTransferReadOpTranspose final
449  : public OpRewritePattern<vector::TransposeOp> {
451 
452  LogicalResult matchAndRewrite(vector::TransposeOp op,
453  PatternRewriter &rewriter) const override {
454  // Look through integer extend ops.
455  Value source = op.getVector();
456  Type resultType = op.getType();
457  Operation *extOp;
458  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
459  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
460  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
461  source = extOp->getOperand(0);
462  resultType =
463  VectorType::get(cast<VectorType>(resultType).getShape(),
464  cast<VectorType>(source.getType()).getElementType());
465  }
466 
467  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
468  if (!transferReadOp)
469  return rewriter.notifyMatchFailure(op, "no transfer read");
470 
471  // TODO: support 0-d corner case.
472  if (transferReadOp.getTransferRank() == 0)
473  return rewriter.notifyMatchFailure(op, "0-D transfer read");
474 
475  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
476  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
477 
478  AffineMap permutationMap =
479  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
480  AffineMap newMap =
481  permutationMap.compose(transferReadOp.getPermutationMap());
482 
483  auto loc = op.getLoc();
484  Value result =
485  rewriter
486  .create<vector::TransferReadOp>(
487  loc, resultType, transferReadOp.getSource(),
488  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
489  transferReadOp.getPadding(), transferReadOp.getMask(),
490  transferReadOp.getInBoundsAttr())
491  .getResult();
492 
493  // Fuse through the integer extend op.
494  if (extOp) {
495  if (isa<arith::ExtSIOp>(extOp))
496  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
497  .getResult();
498  else if (isa<arith::ExtUIOp>(extOp))
499  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
500  .getResult();
501  else
502  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
503  .getResult();
504  }
505 
506  rewriter.replaceOp(op, result);
507  return success();
508  }
509 };
510 
511 } // namespace
512 
513 // MMA types have different layout based on how they are used in matmul ops.
514 // Figure the right layout to use by looking at op uses.
515 // TODO: Change the GPU dialect to abstract the layout at the this level and
516 // only care about it during lowering to NVVM.
517 static const char *inferFragType(Operation *op) {
518  // We can have arith.ext ops before reaching contract ops. See through them
519  // and other kinds of elementwise ops.
520  if (op->hasOneUse()) {
521  Operation *userOp = *op->user_begin();
522  if (userOp->hasTrait<OpTrait::Elementwise>())
523  return inferFragType(userOp);
524  }
525 
526  for (Operation *users : op->getUsers()) {
527  auto contract = dyn_cast<vector::ContractionOp>(users);
528  if (!contract)
529  continue;
530  assert(op->getNumResults() == 1);
531  if (contract.getLhs() == op->getResult(0))
532  return "AOp";
533  if (contract.getRhs() == op->getResult(0))
534  return "BOp";
535  }
536  return "COp";
537 }
538 
539 static LogicalResult
540 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
541  llvm::DenseMap<Value, Value> &valueMapping) {
542  OpBuilder::InsertionGuard g(rewriter);
543  rewriter.setInsertionPoint(op);
544 
545  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
547  "expected convertible operation");
548 
549  std::optional<int64_t> stride =
550  getStaticallyKnownRowStride(op.getShapedType());
551  if (!stride.has_value()) {
552  LLVM_DEBUG(DBGS() << "no stride\n");
553  return rewriter.notifyMatchFailure(op, "no stride");
554  }
555 
556  AffineMap map = op.getPermutationMap();
557  bool isTranspose = isTransposeMatrixLoadMap(map);
558 
559  // Handle broadcast by setting the stride to 0.
560  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
561  assert(cstExpr.getValue() == 0);
562  stride = 0;
563  }
564 
565  Value mappingResult = op.getResult();
566  auto elType = op.getVectorType().getElementType();
567  const char *fragType = inferFragType(op);
568  if (op->hasOneUse()) {
569  auto *user = *op->user_begin();
570  // Infer the signedness of the mma type from the integer extend.
571  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
572  elType = IntegerType::get(
573  op.getContext(), cast<IntegerType>(elType).getWidth(),
574  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
575  : IntegerType::Unsigned);
576  mappingResult = user->getResult(0);
577  }
578  }
579  gpu::MMAMatrixType type =
580  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
581  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
582  op.getLoc(), type, op.getSource(), op.getIndices(),
583  rewriter.getIndexAttr(*stride),
584  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
585  valueMapping[mappingResult] = load;
586 
587  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
588  return success();
589 }
590 
591 static LogicalResult
592 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
593  llvm::DenseMap<Value, Value> &valueMapping) {
594  OpBuilder::InsertionGuard g(rewriter);
595  rewriter.setInsertionPoint(op);
596 
598  std::optional<int64_t> stride =
599  getStaticallyKnownRowStride(op.getShapedType());
600  if (!stride.has_value()) {
601  LLVM_DEBUG(DBGS() << "no stride\n");
602  return rewriter.notifyMatchFailure(op, "no stride");
603  }
604 
605  auto it = valueMapping.find(op.getVector());
606  if (it == valueMapping.end()) {
607  LLVM_DEBUG(DBGS() << "no mapping\n");
608  return rewriter.notifyMatchFailure(op, "no mapping");
609  }
610 
611  Value matrix = it->second;
612  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
613  op.getLoc(), matrix, op.getSource(), op.getIndices(),
614  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
615  (void)store;
616 
617  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
618 
619  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
620  rewriter.eraseOp(op);
621  return success();
622 }
623 
624 /// Returns the vector type which represents a matrix fragment.
625 static VectorType
628  regInfo.elementsPerRegister};
629  Type elType = regInfo.registerLLVMType;
630  if (auto vecType = dyn_cast<VectorType>(elType))
631  elType = vecType.getElementType();
632  return VectorType::get(shape, elType);
633 }
634 
635 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
636 static LogicalResult
637 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
638  llvm::DenseMap<Value, Value> &valueMapping) {
639  OpBuilder::InsertionGuard g(rewriter);
640  rewriter.setInsertionPoint(op);
641 
642  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
644  if (failed(warpMatrixInfo)) {
645  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
646  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
647  }
648 
649  FailureOr<nvgpu::FragmentElementInfo> regInfo =
650  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
651  if (failed(regInfo)) {
652  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
653  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
654  }
655 
656  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
657  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
658  if (!dense) {
659  LLVM_DEBUG(DBGS() << "not a splat\n");
660  return rewriter.notifyMatchFailure(op, "not a splat");
661  }
662 
663  Value result = rewriter.create<arith::ConstantOp>(
664  op.getLoc(), vectorType,
665  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
666  valueMapping[op.getResult()] = result;
667  return success();
668 }
669 
670 /// Check if the loaded matrix operand requires transposed.
671 /// Transposed Map Example:
672 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
673 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
674 /// The code below checks if the output 2D is transposed using a generalized
675 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
676 /// Returns : true; if m > n, false o.w.
677 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
678  mlir::AffineMap map = op.getPermutationMap();
679 
680  if (map.getNumResults() != 2) {
681  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
682  "is not a 2d operand\n");
683  return failure();
684  }
685 
686  // Output 2D matrix dimensions in the order of d0, d1.
687  mlir::AffineExpr dM = map.getResult(0);
688  mlir::AffineExpr dN = map.getResult(1);
689 
690  // Find the position of these expressions in the input.
691  auto exprM = dyn_cast<AffineDimExpr>(dM);
692  auto exprN = dyn_cast<AffineDimExpr>(dN);
693 
694  if (!exprM || !exprN) {
695  LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
696  "expressions, then transpose cannot be determined.\n");
697  return failure();
698  }
699 
700  return exprM.getPosition() > exprN.getPosition();
701 }
702 
703 static LogicalResult
704 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
705  llvm::DenseMap<Value, Value> &valueMapping) {
706  OpBuilder::InsertionGuard g(rewriter);
707  rewriter.setInsertionPoint(op);
708  Location loc = op->getLoc();
709 
710  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
712  if (failed(warpMatrixInfo)) {
713  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
714  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
715  }
716 
717  FailureOr<nvgpu::FragmentElementInfo> regInfo =
718  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
719  if (failed(regInfo)) {
720  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
721  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
722  }
723 
724  FailureOr<bool> transpose = isTransposed(op);
725  if (failed(transpose)) {
726  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
727  return rewriter.notifyMatchFailure(
728  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
729  }
730 
731  FailureOr<nvgpu::LdMatrixParams> params =
732  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
733 
734  if (failed(params)) {
735  LLVM_DEBUG(
736  DBGS()
737  << "failed to convert vector.transfer_read to ldmatrix. "
738  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
739  return rewriter.notifyMatchFailure(
740  op, "failed to convert vector.transfer_read to ldmatrix; this op "
741  "likely should not be converted to a nvgpu.ldmatrix call.");
742  }
743 
744  // Adjust the load offset.
745  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
746  FailureOr<AffineMap> offsets =
747  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
748  if (failed(offsets)) {
749  LLVM_DEBUG(DBGS() << "no offsets\n");
750  return rewriter.notifyMatchFailure(op, "no offsets");
751  }
752 
753  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
754 
755  SmallVector<Value, 4> indices;
756  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
757  indices);
758 
759  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
760  loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
761  valueMapping[op] = newOp->getResult(0);
762  return success();
763 }
764 
765 static LogicalResult
766 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
767  llvm::DenseMap<Value, Value> &valueMapping) {
768  OpBuilder::InsertionGuard g(rewriter);
769  rewriter.setInsertionPoint(op);
770 
771  Location loc = op.getLoc();
772  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
774  if (failed(warpMatrixInfo))
775  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
776  FailureOr<nvgpu::FragmentElementInfo> regInfo =
777  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
778  if (failed(regInfo)) {
779  return rewriter.notifyMatchFailure(
780  op, "Failed to deduce register fragment type during "
781  "conversion to distributed non-ldmatrix compatible load");
782  }
783 
784  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
785  SmallVector<Value, 4> elements;
786 
787  // This is the individual element type.
788  Type loadedElType = regInfo->registerLLVMType;
789  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
790 
791  Value fill = rewriter.create<arith::ConstantOp>(
792  op.getLoc(), vectorType.getElementType(),
793  rewriter.getZeroAttr(vectorType.getElementType()));
794  Value result =
795  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
796 
797  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798 
799  // If we are not transposing, then we can use vectorized loads. Otherwise, we
800  // must load each element individually.
801  if (!isTransposeLoad) {
802  if (!isa<VectorType>(loadedElType)) {
803  loadedElType = VectorType::get({1}, loadedElType);
804  }
805 
806  for (int i = 0; i < vectorType.getShape()[0]; i++) {
807  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
808  rewriter, op.getLoc(), *warpMatrixInfo);
809  if (failed(coords))
810  return rewriter.notifyMatchFailure(op, "no coords");
811 
812  Value logicalValueId = rewriter.create<arith::ConstantOp>(
813  loc, rewriter.getIndexType(),
814  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
815  SmallVector<Value, 4> newIndices;
816  getXferIndices<vector::TransferReadOp>(
817  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
818 
819  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
820  op.getSource(), newIndices);
821  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
822  }
823  } else {
824  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
825  loadedElType = vecType.getElementType();
826  }
827  for (int i = 0; i < vectorType.getShape()[0]; i++) {
828  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
829  innerIdx++) {
830 
831  Value logicalValueId = rewriter.create<arith::ConstantOp>(
832  loc, rewriter.getIndexType(),
833  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
834  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
835  rewriter, op.getLoc(), *warpMatrixInfo);
836  if (failed(coords))
837  return rewriter.notifyMatchFailure(op, "no coords");
838 
839  SmallVector<Value, 4> newIndices;
840  getXferIndices<vector::TransferReadOp>(
841  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
842  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
843  op.getSource(), newIndices);
844  result = rewriter.create<vector::InsertOp>(
845  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
846  }
847  }
848  }
849 
850  valueMapping[op.getResult()] = result;
851  return success();
852 }
853 
854 /// Return true if this is a shared memory memref type.
855 static bool isSharedMemory(MemRefType type) {
856  auto addressSpace =
857  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
858  return addressSpace &&
859  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
860 }
861 
862 /// Converts a `vector.transfer_read` operation directly to either a
863 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
864 /// used when converting to `nvgpu.mma.sync` operations.
///
/// The ldmatrix-based lowering (creatLdMatrixCompatibleLoads) is chosen only
/// when all of the following hold:
///   - the source memref is in GPU workgroup (shared) memory;
///   - nvgpu::inferTileWidthInBits reports a 128-bit tile width;
///   - if the permutation map is not a minor identity (a transposed read), the
///     element bit width is 16, dimension 1 of the vector is at least 8, and
///     dim(0) * bitWidth is at least 128.
/// Otherwise the read is lowered by createNonLdMatrixLoads. Either helper is
/// responsible for recording the converted value in `valueMapping`.
865 static LogicalResult
866 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
867  llvm::DenseMap<Value, Value> &valueMapping) {
868  OpBuilder::InsertionGuard g(rewriter);
869  rewriter.setInsertionPoint(op);
870 
871  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
873  if (failed(warpMatrixInfo))
874  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
875 
// Baseline requirement for ldmatrix: shared-memory source and 128-bit tiles.
876  bool isLdMatrixCompatible =
877  isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
878  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
879 
880  VectorType vecTy = op.getVectorType();
881  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
882 
883  // When we are transposing the B operand, ldmatrix will only work if we have
884  // at least 8 rows to read and the width to read for the transpose is 128
885  // bits.
886  if (!op.getPermutationMap().isMinorIdentity() &&
887  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
888  vecTy.getDimSize(0) * bitWidth < 128))
889  isLdMatrixCompatible = false;
890 
// Fall back to plain (possibly per-element) loads when ldmatrix cannot be used.
891  if (!isLdMatrixCompatible)
892  return createNonLdMatrixLoads(rewriter, op, valueMapping);
893 
894  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
895 }
896 
/// Converts a `vector.transfer_write` whose stored vector has already been
/// converted (looked up in `valueMapping`) into a sequence of `vector.store`
/// ops, for use on the nvgpu.mma.sync path. The converted value is walked along
/// the leading dimension of the mma.sync operand vector type. Element `i` is
/// extracted and stored at the indices obtained by applying the lane-id /
/// logical-value-id coordinate map (logical value id =
/// `i * elementsPerRegister`) on top of the transfer's own indices. The
/// original transfer_write is erased on success.
897 static LogicalResult
898 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
899  llvm::DenseMap<Value, Value> &valueMapping) {
900  OpBuilder::InsertionGuard g(rewriter);
901  rewriter.setInsertionPoint(op);
902 
903  Location loc = op->getLoc();
// The value being written must already have a converted counterpart.
904  auto it = valueMapping.find(op.getVector());
905  if (it == valueMapping.end())
906  return rewriter.notifyMatchFailure(op, "no mapping");
907  Value matrix = it->second;
908 
909  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
911  if (failed(warpMatrixInfo))
912  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
913  FailureOr<nvgpu::FragmentElementInfo> regInfo =
914  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
915  if (failed(regInfo))
916  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
917 
918  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
919  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
920 
// One vector.store per entry along the leading dimension of the fragment.
921  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
922  Value logicalValueId = rewriter.create<arith::ConstantOp>(
923  loc, rewriter.getIndexType(),
924  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
// NOTE(review): `coords` does not depend on `i`; it could be computed once
// before the loop (a failure would then also occur before any IR is created).
925  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
926  rewriter, op.getLoc(), *warpMatrixInfo);
927  if (failed(coords))
928  return rewriter.notifyMatchFailure(op, "no coords");
929 
930  Value el =
931  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
932  SmallVector<Value, 4> newIndices;
933  getXferIndices<vector::TransferWriteOp>(
934  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
935  rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
936  }
937 
938  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
939  rewriter.eraseOp(op);
940  return success();
941 }
942 
943 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944  SmallVectorImpl<int64_t> &results) {
945  for (auto attr : arrayAttr)
946  results.push_back(cast<IntegerAttr>(attr).getInt());
947 }
948 
/// Converts a `vector.extract_strided_slice` whose source is produced by an
/// already-converted `vector.transfer_read` into a
/// `vector.extract_strided_slice` on the converted register vector found in
/// `valueMapping`. The new slice has shape
/// numRegistersPerFragment x elementsPerRegister (taken from
/// `mmaSyncFragmentInfo`) and unit strides. At most one of the two warp-level
/// offsets may be non-zero; 2-D slicing emits an error.
949 static LogicalResult
951  vector::ExtractStridedSliceOp op,
952  llvm::DenseMap<Value, Value> &valueMapping) {
953  OpBuilder::InsertionGuard g(rewriter);
954  rewriter.setInsertionPoint(op);
955 
956  Location loc = op->getLoc();
957 
958  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
960  if (failed(warpMatrixInfo))
961  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
962 
963  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
964  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
965  if (failed(mmaSyncFragmentInfo))
966  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
967 
968  // Find the vector.transfer_read whose result vector is being sliced.
969  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
970  if (!transferReadOp)
971  return rewriter.notifyMatchFailure(op, "no transfer read");
972 
// Fragment info of the load itself; its register layout must agree with the
// mma.sync fragment info computed above (checked by the assert below).
973  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
974  if (failed(warpMatrixInfo))
975  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
976 
977  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
978  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
979  if (failed(ldFragmentInfo))
980  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
981 
982  assert(
983  (mmaSyncFragmentInfo->elementsPerRegister ==
984  ldFragmentInfo->elementsPerRegister) &&
985  "Number of elements per register should be same for load and mma.sync");
986 
987  // Create vector.extract_strided_slice op for thread-owned fragments.
988  std::array<int64_t, 2> strides = {1,
989  1}; // stride for extract slice is always 1.
990  std::array<int64_t, 2> sliceShape = {
991  mmaSyncFragmentInfo->numRegistersPerFragment,
992  mmaSyncFragmentInfo->elementsPerRegister};
// The slice is taken from the converted value of the transfer_read.
993  auto it = valueMapping.find(transferReadOp);
994  if (it == valueMapping.end())
995  return rewriter.notifyMatchFailure(op, "no mapping");
996  auto sourceVector = it->second;
997 
998  // Offsets and sizes at the warp level of ownership.
999  SmallVector<int64_t> offsets;
1000  populateFromInt64AttrArray(op.getOffsets(), offsets);
1001 
// NOTE(review): `sizes` is populated but not read below; the slice size comes
// from `sliceShape` instead.
1002  SmallVector<int64_t> sizes;
1003  populateFromInt64AttrArray(op.getSizes(), sizes);
1004  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005 
1006  // Compute offset in vector registers. Note that the mma.sync vector registers
1007  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1008  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1009  std::array<int64_t, 2> sliceOffset = {0, 0};
1010 
1011  if (offsets[0] && offsets[1])
1012  return op->emitError() << "Slicing fragments in 2D is not supported. ";
// NOTE(review): the register offset is computed as shape / offset rather than
// by scaling the offset. For an offset equal to half the dimension this yields
// 2; confirm the formula matches the register layout for other offsets.
1013  if (offsets[0])
1014  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1015  else if (offsets[1])
1016  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1017 
1018  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1019  loc, sourceVector, sliceOffset, sliceShape, strides);
1020 
1021  valueMapping[op] = newOp;
1022  return success();
1023 }
1024 
1025 static LogicalResult
1026 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1027  llvm::DenseMap<Value, Value> &valueMapping) {
1028  OpBuilder::InsertionGuard g(rewriter);
1029  rewriter.setInsertionPoint(op);
1030 
1031  auto itA = valueMapping.find(op.getLhs());
1032  auto itB = valueMapping.find(op.getRhs());
1033  auto itC = valueMapping.find(op.getAcc());
1034  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1035  itC == valueMapping.end())
1036  return rewriter.notifyMatchFailure(op, "no mapping");
1037  Value opA = itA->second, opB = itB->second, opC = itC->second;
1038  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1039  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1040  /*b_transpose=*/UnitAttr());
1041  valueMapping[op.getResult()] = matmul;
1042  return success();
1043 }
1044 
1045 static LogicalResult
1046 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047  llvm::DenseMap<Value, Value> &valueMapping) {
1048  OpBuilder::InsertionGuard g(rewriter);
1049  rewriter.setInsertionPoint(op);
1050 
1051  auto itA = valueMapping.find(op.getLhs());
1052  auto itB = valueMapping.find(op.getRhs());
1053  auto itC = valueMapping.find(op.getAcc());
1054  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055  itC == valueMapping.end())
1056  return rewriter.notifyMatchFailure(op, "no mapping");
1057  Value opA = itA->second, opB = itB->second, opC = itC->second;
1058  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1062  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1063  valueMapping[op.getResult()] = matmul;
1064  return success();
1065 }
1066 
1067 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
/// The splat element is re-materialized as a scalar arith.constant and used as
/// the fill value. The matrix type is built from the vector's shape and element
/// type plus the fragment kind reported by inferFragType(op). The op must
/// satisfy constantSupportsMMAMatrixType (asserted). The new matrix value is
/// recorded in `valueMapping`; the original constant is left in place.
1068 static LogicalResult
1069 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070  llvm::DenseMap<Value, Value> &valueMapping) {
1071  OpBuilder::InsertionGuard g(rewriter);
1072  rewriter.setInsertionPoint(op);
1073 
1074  assert(constantSupportsMMAMatrixType(op));
1075 
// Extract the single splat value and materialize it as a scalar constant.
1076  auto splat =
1077  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078  auto scalarConstant =
1079  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1080  const char *fragType = inferFragType(op);
1081  auto vecType = cast<VectorType>(op.getType());
1083  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1085  op.getLoc(), type, scalarConstant);
1086  valueMapping[op.getResult()] = matrix;
1087  return success();
1088 }
1089 
1090 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
/// The broadcast's scalar source is used directly as the fill value. The matrix
/// type is built from the result vector's shape and element type plus the
/// fragment kind reported by inferFragType(op). The op must satisfy
/// broadcastSupportsMMAMatrixType (asserted). The new matrix value is recorded
/// in `valueMapping`; the original broadcast is left in place.
1091 static LogicalResult
1092 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093  llvm::DenseMap<Value, Value> &valueMapping) {
1094  OpBuilder::InsertionGuard g(rewriter);
1095  rewriter.setInsertionPoint(op);
1096 
1097  assert(broadcastSupportsMMAMatrixType(op));
1098 
1099  const char *fragType = inferFragType(op);
1100  auto vecType = op.getResultVectorType();
1102  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1104  op.getLoc(), type, op.getSource());
1105  valueMapping[op.getResult()] = matrix;
1106  return success();
1107 }
1108 
1109 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110 // updated and needs to be updated separately for the loop to be correct.
1111 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112  scf::ForOp loop,
1113  ValueRange newInitArgs) {
1114  OpBuilder::InsertionGuard g(rewriter);
1115  rewriter.setInsertionPoint(loop);
1116 
1117  // Create a new loop before the existing one, with the extra operands.
1118  rewriter.setInsertionPoint(loop);
1119  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120  llvm::append_range(operands, newInitArgs);
1121  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1122  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1123  operands);
1124  rewriter.eraseBlock(newLoop.getBody());
1125 
1126  newLoop.getRegion().getBlocks().splice(
1127  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128  for (Value operand : newInitArgs)
1129  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1130 
1131  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132  loop.getNumResults())))
1133  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1134 
1135  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1136  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1137  LLVM_DEBUG(DBGS() << "erase: " << loop);
1138 
1139  rewriter.eraseOp(loop);
1140  return newLoop;
1141 }
1142 
/// Rebuilds the scf.for `op` with one extra iter_arg for every init arg that
/// has a converted counterpart in `valueMapping`. Init args without a mapping
/// are left as they are. Afterwards `valueMapping` maps each affected old
/// result and region iter_arg of the new loop to its newly appended
/// counterpart. The loop's scf.yield is not updated here. `op` is erased by
/// replaceForOpWithNewSignature and must not be used afterwards.
1143 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144  llvm::DenseMap<Value, Value> &valueMapping) {
1145  OpBuilder::InsertionGuard g(rewriter);
1146  rewriter.setInsertionPoint(op);
1147 
// Converted values to append as new init args, plus (original init-arg index,
// index of its appended counterpart) pairs collected in `argMapping`.
1148  SmallVector<Value> newOperands;
1150  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1151  auto it = valueMapping.find(operand.value());
1152  if (it == valueMapping.end()) {
1153  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1154  continue;
1155  }
1156  argMapping.push_back(std::make_pair(
1157  operand.index(), op.getInitArgs().size() + newOperands.size()));
1158  newOperands.push_back(it->second);
1159  }
1160 
1161  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1162  Block &loopBody = *newForOp.getBody();
// Map both the loop results and the region iter_args (offset past the
// induction variables) from the old slot to the appended slot.
1163  for (auto mapping : argMapping) {
1164  valueMapping[newForOp.getResult(mapping.first)] =
1165  newForOp.getResult(mapping.second);
1166  valueMapping[loopBody.getArgument(mapping.first +
1167  newForOp.getNumInductionVars())] =
1168  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1169  }
1170 
1171  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1172  return success();
1173 }
1174 
1175 static LogicalResult
1176 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177  llvm::DenseMap<Value, Value> &valueMapping) {
1178  OpBuilder::InsertionGuard g(rewriter);
1179  rewriter.setInsertionPoint(op);
1180 
1181  auto loop = cast<scf::ForOp>(op->getParentOp());
1182  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1183  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1184  auto it = valueMapping.find(operand.value());
1185  if (it == valueMapping.end())
1186  continue;
1187  // Replace the yield of old value with the for op argument to make it easier
1188  // to remove the dead code.
1189  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190  yieldOperands.push_back(it->second);
1191  }
1192  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1193 
1194  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1195  rewriter.eraseOp(op);
1196  return success();
1197 }
1198 
1199 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
/// Every operand of `op` must already have a converted counterpart in
/// `valueMapping`; otherwise this reports a "no mapping" match failure. The
/// result matrix type is copied from the first converted operand. For EXTF,
/// the element type is instead taken from `op`'s first result vector type.
/// The new gpu.subgroup_mma_elementwise value is recorded for `op`'s first
/// result.
1200 static LogicalResult
1202  gpu::MMAElementwiseOp opType,
1203  llvm::DenseMap<Value, Value> &valueMapping) {
1204  OpBuilder::InsertionGuard g(rewriter);
1205  rewriter.setInsertionPoint(op);
1206 
1207  SmallVector<Value> matrixOperands;
1208  for (Value operand : op->getOperands()) {
1209  auto it = valueMapping.find(operand);
1210  if (it == valueMapping.end())
1211  return rewriter.notifyMatchFailure(op, "no mapping");
1212  matrixOperands.push_back(it->second);
1213  }
// NOTE(review): assumes `op` has at least one operand (matrixOperands[0]).
1214  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215  if (opType == gpu::MMAElementwiseOp::EXTF) {
1216  // The floating point extension case has a different result type.
1217  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1218  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1219  vectorType.getElementType(),
1220  resultType.getOperand());
1221  }
1222 
1223  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1224  op->getLoc(), resultType, matrixOperands, opType);
1225  valueMapping[op->getResult(0)] = newOp;
1226  return success();
1227 }
1228 
1230  bool useNvGpu) {
// Adds the preparation patterns that run before the vector-to-GPU conversion.
// The non-nvgpu (subgroup MMA) path adds PrepareContractToGPUMMA and
// CombineTransferReadOpTranspose, then returns.
1231  if (!useNvGpu) {
1232  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233  patterns.getContext());
1234  return;
1235  }
// nvgpu (mma.sync) path: CombineTransferReadOpTranspose is added here as well.
1237  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1238 }
1239 
1241  Operation *rootOp) {
// Converts the ops selected by getOpToConvert(rootOp, /*useNvGpu=*/false) to
// gpu subgroup MMA ops, in that order. Converted values are threaded through
// `valueMapping`. Ops matching none of the cases below are left untouched.
1242  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1243  llvm::DenseMap<Value, Value> valueMapping;
1244 
1245  auto globalRes = LogicalResult::success();
1246  for (Operation *op : ops) {
1247  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1248  // Apparently callers do not want to early exit on failure here.
// Instead, keep converting the remaining ops and only report in the aggregate
// result whether any single conversion failed.
1249  auto res = LogicalResult::success();
1250  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1251  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1252  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1253  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1254  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1255  res = convertContractOp(rewriter, contractOp, valueMapping);
1256  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1257  res = convertConstantOp(rewriter, constantOp, valueMapping);
1258  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1259  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1260  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1261  res = convertForOp(rewriter, forOp, valueMapping);
1262  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1263  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1264  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1266  }
1267  if (failed(res))
1268  globalRes = failure();
1269  }
1270  return globalRes;
1271 }
1272 
1274  Operation *rootOp) {
// Converts the ops selected by getOpToConvert(rootOp, /*useNvGpu=*/true) to
// nvgpu.mma.sync-compatible IR, in that order, threading converted values
// through `valueMapping`. Unlike convertVectorToMMAOps, this stops at the first
// op that fails to convert: it emits an op error on it and returns failure. An
// op of a kind not listed below also fails, via the Default case.
1275  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1276  llvm::DenseMap<Value, Value> valueMapping;
1277  for (Operation *op : ops) {
1279  .Case([&](vector::TransferReadOp transferReadOp) {
1280  return convertTransferReadToLoads(rewriter, transferReadOp,
1281  valueMapping);
1282  })
1283  .Case([&](vector::TransferWriteOp transferWriteOp) {
1284  return convertTransferWriteToStores(rewriter, transferWriteOp,
1285  valueMapping);
1286  })
1287  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1289  valueMapping);
1290  })
1291  .Case([&](vector::ContractionOp contractionOp) {
1292  return convertContractOpToMmaSync(rewriter, contractionOp,
1293  valueMapping);
1294  })
1295  .Case([&](scf::ForOp forOp) {
1296  return convertForOp(rewriter, forOp, valueMapping);
1297  })
1298  .Case([&](scf::YieldOp yieldOp) {
1299  return convertYieldOp(rewriter, yieldOp, valueMapping);
1300  })
1301  .Case([&](arith::ConstantOp constOp) {
1302  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1303  })
1304  .Default([&](Operation *op) {
1305  return op->emitError() << "unhandled vector to mma type: " << *op;
1306  })
1307  .failed()) {
1308  return op->emitOpError()
1309  << "failed to convert op during vector-to-nvgpu conversion";
1310  }
1311  }
1312  return success();
1313 }
1314 
1315 namespace {
1316 
1317 struct ConvertVectorToGPUPass
1318  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319 
1320  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321  useNvGpu.setValue(useNvGpu_);
1322  }
1323 
1324  void runOnOperation() override {
1325  RewritePatternSet patterns(&getContext());
1326  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1327  if (failed(
1328  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1329  return signalPassFailure();
1330 
1331  IRRewriter rewriter(&getContext());
1332  if (useNvGpu) {
1333  if (failed(
1334  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1335  return signalPassFailure();
1336  return;
1337  }
1338  (void)convertVectorToMMAOps(rewriter, getOperation());
1339  }
1340 };
1341 
1342 } // namespace
1343 
1344 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1345  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1346 }
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:330
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:155
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:381
unsigned getNumResults() const
Definition: AffineMap.cpp:389
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:398
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:251
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:543
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:299
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
user_iterator user_begin()
Definition: Operation.h:865
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:130
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:121
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:238
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:276
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:144
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:139
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:630
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:606
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:30
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52