1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues);
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
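// Editorial note (not in the original source): as an assumed example, with a
// transfer permutation map (d0, d1) -> (d1, d0), an `offsetMap` of
// (lane) -> (lane mod 16, lane floordiv 16), and dimValues = {%laneId}, each
// original index %idx is replaced by an affine.apply of the form
//   affine_map<(lane, idx) -> (lane mod 16 + idx)>(%laneId, %idx)
// i.e. the matching offset expression is added onto the pre-existing index.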
75 
76 // Return true if the contract op can be converted to an MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [&](MapList m) {
81  return AffineMap::inferFromExprList(m, contract.getContext());
82  };
83  AffineExpr m, n, k;
84  bindDims(contract.getContext(), m, n, k);
85  auto iteratorTypes = contract.getIteratorTypes().getValue();
86  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87  vector::isParallelIterator(iteratorTypes[1]) &&
88  vector::isReductionIterator(iteratorTypes[2])))
89  return false;
90 
91  // The contract needs to represent a matmul to be convertible to an
92  // MMAMatrix matmul.
93  if (!useNvGpu &&
94  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95  return false;
96  if (useNvGpu &&
97  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98  return false;
99 
100  return true;
101 }
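// Editorial example (assumed IR, not taken from this file): a contraction that
// passes the non-nvgpu check uses the {(m, k), (k, n), (m, n)} maps:
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                           affine_map<(m, n, k) -> (k, n)>,
//                           affine_map<(m, n, k) -> (m, n)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
//
// For the mma.sync path the B operand is instead expected in (n, k) form.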
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106  MLIRContext *ctx = permutationMap.getContext();
107  // Local OpBuilder is fine here, we just build attributes.
108  OpBuilder b(ctx);
109  auto nDim = permutationMap.getNumDims();
110  AffineExpr zero = b.getAffineConstantExpr(0);
111  if (nDim < 2) {
112  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113  AffineExpr dim0 = b.getAffineDimExpr(0);
114  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115  }
116 
117  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119  // Support both transposed and transposed+broadcasted cases.
120  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
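// Editorial note: for nDim == 2 the maps accepted as transposed loads are
//   affine_map<(d0, d1) -> (d1, d0)>   (plain transpose)
//   affine_map<(d0, d1) -> (d1, 0)>    (transpose + broadcast)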
123 
124 // Return the stride for the second-to-last dimension of |type| if it is a
125 // memref and has a constant stride.
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127  auto memrefType = dyn_cast<MemRefType>(type);
128  if (!memrefType)
129  return std::nullopt;
130  // If the memref is 0 or 1D the horizontal stride is 0.
131  if (memrefType.getRank() < 2)
132  return 0;
133  int64_t offset = 0;
134  SmallVector<int64_t, 2> strides;
135  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
136  strides.back() != 1)
137  return std::nullopt;
138  int64_t stride = strides[strides.size() - 2];
139  if (stride == ShapedType::kDynamic)
140  return std::nullopt;
141  return stride;
142 }
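// Editorial worked example: a memref<16x32xf16> with the default row-major
// identity layout has strides [32, 1], so the returned row stride is 32; a
// layout with a dynamic leading stride, e.g. strided<[?, 1], offset: ?>,
// yields std::nullopt.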
143 
144 // Return true if the transfer op can be converted to an MMA matrix load.
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147  readOp.getVectorType().getRank() != 2)
148  return false;
149  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150  return false;
151 
152  // Only allow integer types if the signedness can be inferred.
153  if (readOp.getVectorType().getElementType().isInteger(8))
154  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156  return false;
157 
158  AffineMap map = readOp.getPermutationMap();
159  MLIRContext *ctx = readOp.getContext();
160  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161  AffineExpr zero = getAffineConstantExpr(0, ctx);
162  auto broadcastInnerDim =
163  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164  return map.isMinorIdentity() || map == broadcastInnerDim ||
165  isTransposeMatrixLoadMap(map);
166 }
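// Editorial example (assumed IR, not taken from this file): a read accepted
// here is a 2-D, unmasked, in-bounds transfer with a (minor) identity,
// broadcast, or transposed permutation map, e.g.
//   %v = vector.transfer_read %mem[%c0, %c0], %pad
//          {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>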
167 
168 // Return true if the transfer op can be converted to an MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171  // TODO: support 0-d corner case.
172  if (writeOp.getTransferRank() == 0)
173  return false;
174 
175  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176  writeOp.getVectorType().getRank() != 2)
177  return false;
178  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179  return false;
180  // TODO: Support transpose once it is added to GPU dialect ops.
181  if (!writeOp.getPermutationMap().isMinorIdentity())
182  return false;
183  return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to an MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189  auto vecType = dyn_cast<VectorType>(constantOp.getType());
190  if (!vecType || vecType.getRank() != 2)
191  return false;
192  return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197  return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203  auto transferReadOp =
204  extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205  if (!transferReadOp)
206  return false;
207  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208 }
209 
210 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211 
212 /// Return the MMA elementwise enum associated with `op` if it is supported.
213 /// Return `std::nullopt` otherwise.
214 static std::optional<gpu::MMAElementwiseOp>
215 convertElementwiseOpToMMA(Operation *op) {
216  if (isa<arith::AddFOp>(op))
217  return gpu::MMAElementwiseOp::ADDF;
218  if (isa<arith::MulFOp>(op))
219  return gpu::MMAElementwiseOp::MULF;
220  if (isa<arith::SubFOp>(op))
221  return gpu::MMAElementwiseOp::SUBF;
222  if (isa<arith::MaximumFOp>(op))
223  return gpu::MMAElementwiseOp::MAXF;
224  if (isa<arith::MinimumFOp>(op))
225  return gpu::MMAElementwiseOp::MINF;
226  if (isa<arith::DivFOp>(op))
227  return gpu::MMAElementwiseOp::DIVF;
228  if (isa<arith::AddIOp>(op))
229  return gpu::MMAElementwiseOp::ADDI;
230  if (isa<arith::MulIOp>(op))
231  return gpu::MMAElementwiseOp::MULI;
232  if (isa<arith::SubIOp>(op))
233  return gpu::MMAElementwiseOp::SUBI;
234  if (isa<arith::DivSIOp>(op))
235  return gpu::MMAElementwiseOp::DIVS;
236  if (isa<arith::DivUIOp>(op))
237  return gpu::MMAElementwiseOp::DIVU;
238  if (isa<arith::NegFOp>(op))
239  return gpu::MMAElementwiseOp::NEGATEF;
240  if (isa<arith::ExtFOp>(op))
241  return gpu::MMAElementwiseOp::EXTF;
242  return std::nullopt;
243 }
244 
245 /// Return true if the op is supported as elementwise op on MMAMatrix type.
246 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247  return convertElementwiseOpToMMA(op).has_value();
248 }
249 
250 /// Returns true if the extract strided slice op is supported with `mma.sync`
251 /// path.
252 static bool
253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254 
255  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256  nvgpu::getWarpMatrixInfo(op);
257  if (failed(warpMatrixInfo))
258  return false;
259 
260  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261  if (failed(contractOp))
262  return false;
263 
264  // Handle vector.extract_strided_slice on registers containing
265  // matrixB and matrixC operands. vector.extract_strided_slice op
266  // is not supported on registers containing matrixA operands.
267  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
268  return (cast<VectorType>(op->getResult(0).getType()) ==
269  cast<VectorType>((*contractOp).getRhs().getType()));
270  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
271  return (cast<VectorType>(op->getResult(0).getType()) ==
272  cast<VectorType>((*contractOp).getAcc().getType()));
273 
274  return false;
275 }
276 
277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278  if (isa<scf::ForOp, scf::YieldOp>(op))
279  return true;
280  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282  : transferReadSupportsMMAMatrixType(transferRead);
283  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285  : transferWriteSupportsMMAMatrixType(transferWrite);
286  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287  return useNvGpu &&
288  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289  if (auto contract = dyn_cast<vector::ContractionOp>(op))
290  return contractSupportsMMAMatrixType(contract, useNvGpu);
291  if (auto constant = dyn_cast<arith::ConstantOp>(op))
292  return constantSupportsMMAMatrixType(constant);
293  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
294  return broadcastSupportsMMAMatrixType(broadcast);
295  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
297  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
298  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300  return fpExtendSupportsMMAMatrixType(fpExtend);
301  return elementwiseSupportsMMAMatrixType(op);
302 }
303 
304 /// Return an unsorted slice handling scf.for region differently than
305 /// `getSlice`. In scf.for we only want to include as part of the slice the
306 /// elements that are part of the use/def chain.
307 static SetVector<Operation *>
308 getSliceContract(Operation *op,
309  const BackwardSliceOptions &backwardSliceOptions,
310  const ForwardSliceOptions &forwardSliceOptions) {
311  SetVector<Operation *> slice;
312  slice.insert(op);
313  unsigned currentIndex = 0;
314  SetVector<Operation *> backwardSlice;
315  SetVector<Operation *> forwardSlice;
316  while (currentIndex != slice.size()) {
317  auto *currentOp = (slice)[currentIndex];
318  // Compute and insert the backwardSlice starting from currentOp.
319  backwardSlice.clear();
320  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
321  slice.insert_range(backwardSlice);
322 
323  // Compute and insert the forwardSlice starting from currentOp.
324  forwardSlice.clear();
325  // Special case for ForOp: we don't want to include the whole region, but
326  // only the values using the region arguments.
327  // TODO: We should refine this to only care about the region arguments being
328  // converted to matrix type.
329  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
330  for (Value forOpResult : forOp.getResults())
331  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
332  for (BlockArgument &arg : forOp.getRegionIterArgs())
333  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
334  } else {
335  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
336  }
337  slice.insert_range(forwardSlice);
338  ++currentIndex;
339  }
340  return slice;
341 }
342 
343 // Analyze the slice of operations based on the convert op to figure out if
344 // the whole slice can be converted to MMA operations.
345 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
346  bool useNvGpu) {
347  auto hasVectorDest = [](Operation *op) {
348  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
349  };
350  BackwardSliceOptions backwardSliceOptions;
351  backwardSliceOptions.filter = hasVectorDest;
352 
353  auto hasVectorSrc = [](Operation *op) {
354  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
355  };
356  ForwardSliceOptions forwardSliceOptions;
357  forwardSliceOptions.filter = hasVectorSrc;
358 
359  SetVector<Operation *> opToConvert;
360  op->walk([&](vector::ContractionOp contract) {
361  if (opToConvert.contains(contract.getOperation()))
362  return;
363  SetVector<Operation *> dependentOps =
364  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
365  // If any instruction cannot use the MMA matrix type, drop the whole
366  // chain. MMA matrices are stored in an opaque type, so they cannot be
367  // used by all operations.
368  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369  if (!supportsMMaMatrixType(op, useNvGpu)) {
370  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371  return true;
372  }
373  return false;
374  }))
375  return;
376 
377  opToConvert.insert_range(dependentOps);
378  });
379  // Sort the operations so that we can convert them in topological order.
380  return topologicalSort(opToConvert);
381 }
382 
383 namespace {
384 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
385 // to an MMA matmul.
386 struct PrepareContractToGPUMMA
387  : public OpRewritePattern<vector::ContractionOp> {
388  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389 
390  LogicalResult matchAndRewrite(vector::ContractionOp op,
391  PatternRewriter &rewriter) const override {
392  Location loc = op.getLoc();
393  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394 
395  // Set up the parallel/reduction structure in right form.
396  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397  auto infer = [&](MapList m) {
398  return AffineMap::inferFromExprList(m, op.getContext());
399  };
400  AffineExpr m, n, k;
401  bindDims(rewriter.getContext(), m, n, k);
402  static constexpr std::array<int64_t, 2> perm = {1, 0};
403  auto iteratorTypes = op.getIteratorTypes().getValue();
404  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
405  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
406  vector::isParallelIterator(iteratorTypes[1]) &&
407  vector::isReductionIterator(iteratorTypes[2])))
408  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
409  //
410  // Two outer parallel, one inner reduction (matmat flavor).
411  //
412  // This is the classical row-major matmul, nothing to do.
413  if (maps == infer({{m, k}, {k, n}, {m, n}}))
414  return rewriter.notifyMatchFailure(op, "contraction already prepared");
415  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
416  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
417  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
418  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
420  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
421  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
423  std::swap(rhs, lhs);
424  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
425  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
426  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
427  std::swap(rhs, lhs);
428  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
429  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
430  std::swap(lhs, rhs);
431  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
432  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
433  std::swap(lhs, rhs);
434  } else {
435  // TODO: llvm_unreachable ?
436  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
437  }
438  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
439  op, lhs, rhs, res,
440  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
441  op.getIteratorTypes());
442  return success();
443  }
444 };
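// Editorial note: as an example, a contraction whose maps are
// {(k, m), (k, n), (m, n)} (transposed LHS) is rewritten by this pattern into
// a vector.transpose of the LHS with permutation [1, 0] followed by a
// contraction using the canonical {(m, k), (k, n), (m, n)} maps.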
445 
446 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
447 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
448 // respectively. We can fold the transpose operation when loading the data from
449 // Shared Memory to registers.
450 struct CombineTransferReadOpTranspose final
451  : public OpRewritePattern<vector::TransposeOp> {
452  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
453 
454  LogicalResult matchAndRewrite(vector::TransposeOp op,
455  PatternRewriter &rewriter) const override {
456  // Look through integer extend ops.
457  Value source = op.getVector();
458  Type resultType = op.getType();
459  Operation *extOp;
460  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
461  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
462  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
463  source = extOp->getOperand(0);
464  resultType =
465  VectorType::get(cast<VectorType>(resultType).getShape(),
466  cast<VectorType>(source.getType()).getElementType());
467  }
468 
469  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
470  if (!transferReadOp)
471  return rewriter.notifyMatchFailure(op, "no transfer read");
472 
473  // TODO: support 0-d corner case.
474  if (transferReadOp.getTransferRank() == 0)
475  return rewriter.notifyMatchFailure(op, "0-D transfer read");
476 
477  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
478  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
479 
480  AffineMap permutationMap =
481  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
482  AffineMap newMap =
483  permutationMap.compose(transferReadOp.getPermutationMap());
484 
485  auto loc = op.getLoc();
486  Value result =
487  rewriter
488  .create<vector::TransferReadOp>(
489  loc, resultType, transferReadOp.getSource(),
490  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
491  transferReadOp.getPadding(), transferReadOp.getMask(),
492  transferReadOp.getInBoundsAttr())
493  .getResult();
494 
495  // Fuse through the integer extend op.
496  if (extOp) {
497  if (isa<arith::ExtSIOp>(extOp))
498  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
499  .getResult();
500  else if (isa<arith::ExtUIOp>(extOp))
501  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
502  .getResult();
503  else
504  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
505  .getResult();
506  }
507 
508  rewriter.replaceOp(op, result);
509  return success();
510  }
511 };
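// Editorial note: for instance, a minor-identity vector.transfer_read followed
// by vector.transpose %r, [1, 0] is rewritten into a single transfer_read
// whose permutation map is composed with affine_map<(d0, d1) -> (d1, d0)>, so
// the later conversion can emit a column-major (transposed) matrix load
// instead of a separate transpose.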
512 
513 } // namespace
514 
515 // MMA types have different layouts based on how they are used in matmul ops.
516 // Figure out the right layout to use by looking at op uses.
517 // TODO: Change the GPU dialect to abstract the layout at this level and
518 // only care about it during lowering to NVVM.
519 static const char *inferFragType(Operation *op) {
520  // We can have arith.ext ops before reaching contract ops. See through them
521  // and other kinds of elementwise ops.
522  if (op->hasOneUse()) {
523  Operation *userOp = *op->user_begin();
524  if (userOp->hasTrait<OpTrait::Elementwise>())
525  return inferFragType(userOp);
526  }
527 
528  for (Operation *users : op->getUsers()) {
529  auto contract = dyn_cast<vector::ContractionOp>(users);
530  if (!contract)
531  continue;
532  assert(op->getNumResults() == 1);
533  if (contract.getLhs() == op->getResult(0))
534  return "AOp";
535  if (contract.getRhs() == op->getResult(0))
536  return "BOp";
537  }
538  return "COp";
539 }
540 
541 static LogicalResult
542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
543  llvm::DenseMap<Value, Value> &valueMapping) {
544  OpBuilder::InsertionGuard g(rewriter);
545  rewriter.setInsertionPoint(op);
546 
547  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548  assert(transferReadSupportsMMAMatrixType(op) &&
549  "expected convertible operation");
550 
551  std::optional<int64_t> stride =
552  getStaticallyKnownRowStride(op.getShapedType());
553  if (!stride.has_value()) {
554  LLVM_DEBUG(DBGS() << "no stride\n");
555  return rewriter.notifyMatchFailure(op, "no stride");
556  }
557 
558  AffineMap map = op.getPermutationMap();
559  bool isTranspose = isTransposeMatrixLoadMap(map);
560 
561  // Handle broadcast by setting the stride to 0.
562  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
563  assert(cstExpr.getValue() == 0);
564  stride = 0;
565  }
566 
567  Value mappingResult = op.getResult();
568  auto elType = op.getVectorType().getElementType();
569  const char *fragType = inferFragType(op);
570  if (op->hasOneUse()) {
571  auto *user = *op->user_begin();
572  // Infer the signedness of the mma type from the integer extend.
573  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
574  elType = IntegerType::get(
575  op.getContext(), cast<IntegerType>(elType).getWidth(),
576  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
577  : IntegerType::Unsigned);
578  mappingResult = user->getResult(0);
579  }
580  }
581  gpu::MMAMatrixType type =
582  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
583  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
584  op.getLoc(), type, op.getSource(), op.getIndices(),
585  rewriter.getIndexAttr(*stride),
586  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
587  valueMapping[mappingResult] = load;
588 
589  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
590  return success();
591 }
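// Editorial example (assumed IR, syntax sketched from the GPU dialect): a
// 16x16 f16 read with row stride 32 that feeds the LHS of a contraction is
// rewritten to roughly
//   %0 = gpu.subgroup_mma_load_matrix %src[%c0, %c0] {leadDimension = 32 : index}
//          : memref<16x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">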
592 
593 static LogicalResult
594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
595  llvm::DenseMap<Value, Value> &valueMapping) {
596  OpBuilder::InsertionGuard g(rewriter);
597  rewriter.setInsertionPoint(op);
598 
599  assert(transferWriteSupportsMMAMatrixType(op));
600  std::optional<int64_t> stride =
601  getStaticallyKnownRowStride(op.getShapedType());
602  if (!stride.has_value()) {
603  LLVM_DEBUG(DBGS() << "no stride\n");
604  return rewriter.notifyMatchFailure(op, "no stride");
605  }
606 
607  auto it = valueMapping.find(op.getVector());
608  if (it == valueMapping.end()) {
609  LLVM_DEBUG(DBGS() << "no mapping\n");
610  return rewriter.notifyMatchFailure(op, "no mapping");
611  }
612 
613  Value matrix = it->second;
614  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
615  op.getLoc(), matrix, op.getSource(), op.getIndices(),
616  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
617  (void)store;
618 
619  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
620 
621  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
622  rewriter.eraseOp(op);
623  return success();
624 }
625 
626 /// Returns the vector type which represents a matrix fragment.
627 static VectorType
628 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
629  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
630  regInfo.elementsPerRegister};
631  Type elType = regInfo.registerLLVMType;
632  if (auto vecType = dyn_cast<VectorType>(elType))
633  elType = vecType.getElementType();
634  return VectorType::get(shape, elType);
635 }
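// Editorial worked example: assuming numRegistersPerFragment = 2,
// elementsPerRegister = 2, and registerLLVMType = vector<2xf16>, the operand
// type produced here is vector<2x2xf16> (the inner vector's element type is
// unwrapped).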
636 
637 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
638 static LogicalResult
639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
640  llvm::DenseMap<Value, Value> &valueMapping) {
641  OpBuilder::InsertionGuard g(rewriter);
642  rewriter.setInsertionPoint(op);
643 
644  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
645  nvgpu::getWarpMatrixInfo(op);
646  if (failed(warpMatrixInfo)) {
647  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
648  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
649  }
650 
651  FailureOr<nvgpu::FragmentElementInfo> regInfo =
652  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
653  if (failed(regInfo)) {
654  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
655  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
656  }
657 
658  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
659  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
660  if (!dense) {
661  LLVM_DEBUG(DBGS() << "not a splat\n");
662  return rewriter.notifyMatchFailure(op, "not a splat");
663  }
664 
665  Value result = rewriter.create<arith::ConstantOp>(
666  op.getLoc(), vectorType,
667  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
668  valueMapping[op.getResult()] = result;
669  return success();
670 }
671 
672 /// Check if the loaded matrix operand requires a transpose.
673 /// Transposed Map Example:
674 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
675 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
676 /// The code below checks if the output 2D map is transposed using a
677 /// generalized version: (d0, d1, ..., dn, ..., dm, ...) -> (dm, dn).
678 /// Returns true if m > n, false otherwise.
679 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
680  mlir::AffineMap map = op.getPermutationMap();
681 
682  if (map.getNumResults() != 2) {
683  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
684  "is not a 2d operand\n");
685  return failure();
686  }
687 
688  // Output 2D matrix dimensions in the order of d0, d1.
689  mlir::AffineExpr dM = map.getResult(0);
690  mlir::AffineExpr dN = map.getResult(1);
691 
692  // Find the position of these expressions in the input.
693  auto exprM = dyn_cast<AffineDimExpr>(dM);
694  auto exprN = dyn_cast<AffineDimExpr>(dN);
695 
696  if (!exprM || !exprN) {
697  LLVM_DEBUG(DBGS() << "Failed because the expressions are not affine dim "
698  "expressions, so the transpose cannot be determined.\n");
699  return failure();
700  }
701 
702  return exprM.getPosition() > exprN.getPosition();
703 }
704 
705 static LogicalResult
706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
707  llvm::DenseMap<Value, Value> &valueMapping) {
708  OpBuilder::InsertionGuard g(rewriter);
709  rewriter.setInsertionPoint(op);
710  Location loc = op->getLoc();
711 
712  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
713  nvgpu::getWarpMatrixInfo(op);
714  if (failed(warpMatrixInfo)) {
715  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
716  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
717  }
718 
719  FailureOr<nvgpu::FragmentElementInfo> regInfo =
720  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
721  if (failed(regInfo)) {
722  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
723  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
724  }
725 
726  FailureOr<bool> transpose = isTransposed(op);
727  if (failed(transpose)) {
728  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
729  return rewriter.notifyMatchFailure(
730  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
731  }
732 
733  FailureOr<nvgpu::LdMatrixParams> params =
734  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
735 
736  if (failed(params)) {
737  LLVM_DEBUG(
738  DBGS()
739  << "failed to convert vector.transfer_read to ldmatrix. "
740  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
741  return rewriter.notifyMatchFailure(
742  op, "failed to convert vector.transfer_read to ldmatrix; this op "
743  "likely should not be converted to a nvgpu.ldmatrix call.");
744  }
745 
746  // Adjust the load offset.
747  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
748  FailureOr<AffineMap> offsets =
749  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
750  if (failed(offsets)) {
751  LLVM_DEBUG(DBGS() << "no offsets\n");
752  return rewriter.notifyMatchFailure(op, "no offsets");
753  }
754 
755  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
756 
757  SmallVector<Value, 4> indices;
758  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
759  indices);
760 
761  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
762  loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
763  valueMapping[op] = newOp->getResult(0);
764  return success();
765 }
766 
767 static LogicalResult
768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
769  llvm::DenseMap<Value, Value> &valueMapping) {
770  OpBuilder::InsertionGuard g(rewriter);
771  rewriter.setInsertionPoint(op);
772 
773  Location loc = op.getLoc();
774  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
775  nvgpu::getWarpMatrixInfo(op);
776  if (failed(warpMatrixInfo))
777  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
778  FailureOr<nvgpu::FragmentElementInfo> regInfo =
779  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
780  if (failed(regInfo)) {
781  return rewriter.notifyMatchFailure(
782  op, "Failed to deduce register fragment type during "
783  "conversion to distributed non-ldmatrix compatible load");
784  }
785 
786  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
787 
788  // This is the individual element type.
789  Type loadedElType = regInfo->registerLLVMType;
790  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
791 
792  Value fill = rewriter.create<arith::ConstantOp>(
793  op.getLoc(), vectorType.getElementType(),
794  rewriter.getZeroAttr(vectorType.getElementType()));
795  Value result =
796  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
797 
798  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
799 
800  // If we are not transposing, then we can use vectorized loads. Otherwise, we
801  // must load each element individually.
802  if (!isTransposeLoad) {
803  if (!isa<VectorType>(loadedElType)) {
804  loadedElType = VectorType::get({1}, loadedElType);
805  }
806 
807  for (int i = 0; i < vectorType.getShape()[0]; i++) {
808  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
809  rewriter, op.getLoc(), *warpMatrixInfo);
810  if (failed(coords))
811  return rewriter.notifyMatchFailure(op, "no coords");
812 
813  Value logicalValueId = rewriter.create<arith::ConstantOp>(
814  loc, rewriter.getIndexType(),
815  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
816  SmallVector<Value, 4> newIndices;
817  getXferIndices<vector::TransferReadOp>(
818  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
819 
820  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
821  op.getSource(), newIndices);
822  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
823  }
824  } else {
825  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
826  loadedElType = vecType.getElementType();
827  }
828  for (int i = 0; i < vectorType.getShape()[0]; i++) {
829  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
830  innerIdx++) {
831 
832  Value logicalValueId = rewriter.create<arith::ConstantOp>(
833  loc, rewriter.getIndexType(),
834  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
835  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
836  rewriter, op.getLoc(), *warpMatrixInfo);
837  if (failed(coords))
838  return rewriter.notifyMatchFailure(op, "no coords");
839 
840  SmallVector<Value, 4> newIndices;
841  getXferIndices<vector::TransferReadOp>(
842  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
843  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
844  op.getSource(), newIndices);
845  result = rewriter.create<vector::InsertOp>(
846  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
847  }
848  }
849  }
850 
851  valueMapping[op.getResult()] = result;
852  return success();
853 }
854 
855 /// Return true if this is a shared memory memref type.
856 static bool isSharedMemory(MemRefType type) {
857  auto addressSpace =
858  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
859  return addressSpace &&
860  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
861 }
862 
863 /// Converts a `vector.transfer_read` operation directly to either a
864 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
865 /// used when converting to `nvgpu.mma.sync` operations.
866 static LogicalResult
867 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
868  llvm::DenseMap<Value, Value> &valueMapping) {
869  OpBuilder::InsertionGuard g(rewriter);
870  rewriter.setInsertionPoint(op);
871 
872  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
873  nvgpu::getWarpMatrixInfo(op);
874  if (failed(warpMatrixInfo))
875  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
876 
877  bool isLdMatrixCompatible =
878  isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
879  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
880 
881  VectorType vecTy = op.getVectorType();
882  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
883 
884  // When we are transposing the B operand, ldmatrix will only work if we have
885  // at least 8 rows to read and the width to read for the transpose is 128
886  // bits.
887  if (!op.getPermutationMap().isMinorIdentity() &&
888  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
889  vecTy.getDimSize(0) * bitWidth < 128))
890  isLdMatrixCompatible = false;
891 
892  if (!isLdMatrixCompatible)
893  return createNonLdMatrixLoads(rewriter, op, valueMapping);
894 
895  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
896 }
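// Editorial worked example: for a transposed f16 B operand (16-bit elements),
// ldmatrix remains usable only when vecTy.getDimSize(0) * 16 >= 128 (at least
// 8 elements, i.e. 128 bits, along dim 0) and vecTy.getDimSize(1) >= 8;
// otherwise the per-element createNonLdMatrixLoads path above is taken.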
897 
898 static LogicalResult
899 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
900  llvm::DenseMap<Value, Value> &valueMapping) {
901  OpBuilder::InsertionGuard g(rewriter);
902  rewriter.setInsertionPoint(op);
903 
904  Location loc = op->getLoc();
905  auto it = valueMapping.find(op.getVector());
906  if (it == valueMapping.end())
907  return rewriter.notifyMatchFailure(op, "no mapping");
908  Value matrix = it->second;
909 
910  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
911  nvgpu::getWarpMatrixInfo(op);
912  if (failed(warpMatrixInfo))
913  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
914  FailureOr<nvgpu::FragmentElementInfo> regInfo =
915  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
916  if (failed(regInfo))
917  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
918 
919  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
920  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
921 
922  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
923  Value logicalValueId = rewriter.create<arith::ConstantOp>(
924  loc, rewriter.getIndexType(),
925  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
926  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
927  rewriter, op.getLoc(), *warpMatrixInfo);
928  if (failed(coords))
929  return rewriter.notifyMatchFailure(op, "no coords");
930 
931  Value el =
932  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
933  SmallVector<Value, 4> newIndices;
934  getXferIndices<vector::TransferWriteOp>(
935  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
936  rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
937  }
938 
939  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
940  rewriter.eraseOp(op);
941  return success();
942 }
943 
944 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
945  SmallVectorImpl<int64_t> &results) {
946  for (auto attr : arrayAttr)
947  results.push_back(cast<IntegerAttr>(attr).getInt());
948 }
949 
950 static LogicalResult
951 convertExtractStridedSlice(RewriterBase &rewriter,
952  vector::ExtractStridedSliceOp op,
953  llvm::DenseMap<Value, Value> &valueMapping) {
954  OpBuilder::InsertionGuard g(rewriter);
955  rewriter.setInsertionPoint(op);
956 
957  Location loc = op->getLoc();
958 
959  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
960  nvgpu::getWarpMatrixInfo(op);
961  if (failed(warpMatrixInfo))
962  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
963 
964  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
965  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
966  if (failed(mmaSyncFragmentInfo))
967  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
968 
969  // Find the vector.transfer_read whose result vector is being sliced.
970  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
971  if (!transferReadOp)
972  return rewriter.notifyMatchFailure(op, "no transfer read");
973 
974  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
975  if (failed(warpMatrixInfo))
976  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
977 
978  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
979  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
980  if (failed(ldFragmentInfo))
981  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
982 
983  assert(
984  (mmaSyncFragmentInfo->elementsPerRegister ==
985  ldFragmentInfo->elementsPerRegister) &&
986  "Number of elements per register should be same for load and mma.sync");
987 
988  // Create vector.extract_strided_slice op for thread-owned fragments.
989  std::array<int64_t, 2> strides = {1,
990  1}; // stride for extract slice is always 1.
991  std::array<int64_t, 2> sliceShape = {
992  mmaSyncFragmentInfo->numRegistersPerFragment,
993  mmaSyncFragmentInfo->elementsPerRegister};
994  auto it = valueMapping.find(transferReadOp);
995  if (it == valueMapping.end())
996  return rewriter.notifyMatchFailure(op, "no mapping");
997  auto sourceVector = it->second;
998 
999  // Offsets and sizes at warp-level of ownership.
1000  SmallVector<int64_t> offsets;
1001  populateFromInt64AttrArray(op.getOffsets(), offsets);
1002 
1003  SmallVector<int64_t> sizes;
1004  populateFromInt64AttrArray(op.getSizes(), sizes);
1005  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1006 
1007  // Compute offset in vector registers. Note that the mma.sync vector registers
1008  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1009  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1010  std::array<int64_t, 2> sliceOffset = {0, 0};
1011 
1012  if (offsets[0] && offsets[1])
1013  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1014  if (offsets[0])
1015  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1016  else if (offsets[1])
1017  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1018 
1019  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1020  loc, sourceVector, sliceOffset, sliceShape, strides);
1021 
1022  valueMapping[op] = newOp;
1023  return success();
1024 }
1025 
1026 static LogicalResult
1027 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1028  llvm::DenseMap<Value, Value> &valueMapping) {
1029  OpBuilder::InsertionGuard g(rewriter);
1030  rewriter.setInsertionPoint(op);
1031 
1032  auto itA = valueMapping.find(op.getLhs());
1033  auto itB = valueMapping.find(op.getRhs());
1034  auto itC = valueMapping.find(op.getAcc());
1035  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1036  itC == valueMapping.end())
1037  return rewriter.notifyMatchFailure(op, "no mapping");
1038  Value opA = itA->second, opB = itB->second, opC = itC->second;
1039  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1040  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1041  /*b_transpose=*/UnitAttr());
1042  valueMapping[op.getResult()] = matmul;
1043  return success();
1044 }
1045 
1046 static LogicalResult
1047 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1048  llvm::DenseMap<Value, Value> &valueMapping) {
1049  OpBuilder::InsertionGuard g(rewriter);
1050  rewriter.setInsertionPoint(op);
1051 
1052  auto itA = valueMapping.find(op.getLhs());
1053  auto itB = valueMapping.find(op.getRhs());
1054  auto itC = valueMapping.find(op.getAcc());
1055  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1056  itC == valueMapping.end())
1057  return rewriter.notifyMatchFailure(op, "no mapping");
1058  Value opA = itA->second, opB = itB->second, opC = itC->second;
1059  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1060  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1061  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1062  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1063  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1064  valueMapping[op.getResult()] = matmul;
1065  return success();
1066 }
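// Editorial worked example: with %lhs : vector<16x16xf16> and
// %rhs : vector<8x16xf16>, this extracts m = 16, n = 8, k = 16 and emits
// nvgpu.mma.sync with mmaShape = [16, 8, 16] (the m16n8k16 configuration).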
1067 
1068 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1069 static LogicalResult
1070 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1071  llvm::DenseMap<Value, Value> &valueMapping) {
1072  OpBuilder::InsertionGuard g(rewriter);
1073  rewriter.setInsertionPoint(op);
1074 
1075  assert(constantSupportsMMAMatrixType(op));
1076 
1077  auto splat =
1078  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1079  auto scalarConstant =
1080  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1081  const char *fragType = inferFragType(op);
1082  auto vecType = cast<VectorType>(op.getType());
1083  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1084  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1085  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1086  op.getLoc(), type, scalarConstant);
1087  valueMapping[op.getResult()] = matrix;
1088  return success();
1089 }
1090 
1091 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1092 static LogicalResult
1093 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1094  llvm::DenseMap<Value, Value> &valueMapping) {
1095  OpBuilder::InsertionGuard g(rewriter);
1096  rewriter.setInsertionPoint(op);
1097 
1098  assert(broadcastSupportsMMAMatrixType(op));
1099 
1100  const char *fragType = inferFragType(op);
1101  auto vecType = op.getResultVectorType();
1102  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1103  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1104  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1105  op.getLoc(), type, op.getSource());
1106  valueMapping[op.getResult()] = matrix;
1107  return success();
1108 }
1109 
1110 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1111 // updated and needs to be updated separately for the loop to be correct.
1112 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1113  scf::ForOp loop,
1114  ValueRange newInitArgs) {
1115  OpBuilder::InsertionGuard g(rewriter);
1116  rewriter.setInsertionPoint(loop);
1117 
1118  // Create a new loop before the existing one, with the extra operands.
1119  rewriter.setInsertionPoint(loop);
1120  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1121  llvm::append_range(operands, newInitArgs);
1122  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1123  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1124  operands);
1125  rewriter.eraseBlock(newLoop.getBody());
1126 
1127  newLoop.getRegion().getBlocks().splice(
1128  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1129  for (Value operand : newInitArgs)
1130  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1131 
1132  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1133  loop.getNumResults())))
1134  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1135 
1136  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1137  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1138  LLVM_DEBUG(DBGS() << "erase: " << loop);
1139 
1140  rewriter.eraseOp(loop);
1141  return newLoop;
1142 }
1143 
1144 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1145  llvm::DenseMap<Value, Value> &valueMapping) {
1146  OpBuilder::InsertionGuard g(rewriter);
1147  rewriter.setInsertionPoint(op);
1148 
1149  SmallVector<Value> newOperands;
1150  SmallVector<std::pair<size_t, size_t>> argMapping;
1151  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1152  auto it = valueMapping.find(operand.value());
1153  if (it == valueMapping.end()) {
1154  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1155  continue;
1156  }
1157  argMapping.push_back(std::make_pair(
1158  operand.index(), op.getInitArgs().size() + newOperands.size()));
1159  newOperands.push_back(it->second);
1160  }
1161 
1162  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1163  Block &loopBody = *newForOp.getBody();
1164  for (auto mapping : argMapping) {
1165  valueMapping[newForOp.getResult(mapping.first)] =
1166  newForOp.getResult(mapping.second);
1167  valueMapping[loopBody.getArgument(mapping.first +
1168  newForOp.getNumInductionVars())] =
1169  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1170  }
1171 
1172  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1173  return success();
1174 }
1175 
1176 static LogicalResult
1177 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1178  llvm::DenseMap<Value, Value> &valueMapping) {
1179  OpBuilder::InsertionGuard g(rewriter);
1180  rewriter.setInsertionPoint(op);
1181 
1182  auto loop = cast<scf::ForOp>(op->getParentOp());
1183  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1184  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1185  auto it = valueMapping.find(operand.value());
1186  if (it == valueMapping.end())
1187  continue;
1188  // Replace the yield of the old value with the for op argument to make it
1189  // easier to remove the dead code.
1190  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1191  yieldOperands.push_back(it->second);
1192  }
1193  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1194 
1195  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1196  rewriter.eraseOp(op);
1197  return success();
1198 }
1199 
1200 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1201 static LogicalResult
1202 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1203  gpu::MMAElementwiseOp opType,
1204  llvm::DenseMap<Value, Value> &valueMapping) {
1205  OpBuilder::InsertionGuard g(rewriter);
1206  rewriter.setInsertionPoint(op);
1207 
1208  SmallVector<Value> matrixOperands;
1209  for (Value operand : op->getOperands()) {
1210  auto it = valueMapping.find(operand);
1211  if (it == valueMapping.end())
1212  return rewriter.notifyMatchFailure(op, "no mapping");
1213  matrixOperands.push_back(it->second);
1214  }
1215  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1216  if (opType == gpu::MMAElementwiseOp::EXTF) {
1217  // The floating point extension case has a different result type.
1218  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1219  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1220  vectorType.getElementType(),
1221  resultType.getOperand());
1222  }
1223 
1224  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1225  op->getLoc(), resultType, matrixOperands, opType);
1226  valueMapping[op->getResult(0)] = newOp;
1227  return success();
1228 }
1229 
1230 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1231  bool useNvGpu) {
1232  if (!useNvGpu) {
1233  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1234  patterns.getContext());
1235  return;
1236  }
1237  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1238  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1239 }
1240 
1241 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1242  Operation *rootOp) {
1243  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1244  llvm::DenseMap<Value, Value> valueMapping;
1245 
1246  auto globalRes = LogicalResult::success();
1247  for (Operation *op : ops) {
1248  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1249  // Apparently callers do not want to early exit on failure here.
1250  auto res = LogicalResult::success();
1251  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1252  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1253  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1254  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1255  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1256  res = convertContractOp(rewriter, contractOp, valueMapping);
1257  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1258  res = convertConstantOp(rewriter, constantOp, valueMapping);
1259  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1260  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1261  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1262  res = convertForOp(rewriter, forOp, valueMapping);
1263  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1264  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1265  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1266  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1267  }
1268  if (failed(res))
1269  globalRes = failure();
1270  }
1271  return globalRes;
1272 }
1273 
1274 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1275  Operation *rootOp) {
1276  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1277  llvm::DenseMap<Value, Value> valueMapping;
1278  for (Operation *op : ops) {
1279  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1280  .Case([&](vector::TransferReadOp transferReadOp) {
1281  return convertTransferReadToLoads(rewriter, transferReadOp,
1282  valueMapping);
1283  })
1284  .Case([&](vector::TransferWriteOp transferWriteOp) {
1285  return convertTransferWriteToStores(rewriter, transferWriteOp,
1286  valueMapping);
1287  })
1288  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1289  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1290  valueMapping);
1291  })
1292  .Case([&](vector::ContractionOp contractionOp) {
1293  return convertContractOpToMmaSync(rewriter, contractionOp,
1294  valueMapping);
1295  })
1296  .Case([&](scf::ForOp forOp) {
1297  return convertForOp(rewriter, forOp, valueMapping);
1298  })
1299  .Case([&](scf::YieldOp yieldOp) {
1300  return convertYieldOp(rewriter, yieldOp, valueMapping);
1301  })
1302  .Case([&](arith::ConstantOp constOp) {
1303  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1304  })
1305  .Default([&](Operation *op) {
1306  return op->emitError() << "unhandled vector to mma type: " << *op;
1307  })
1308  .failed()) {
1309  return op->emitOpError()
1310  << "failed to convert op during vector-to-nvgpu conversion";
1311  }
1312  }
1313  return success();
1314 }
1315 
1316 namespace {
1317 
1318 struct ConvertVectorToGPUPass
1319  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1320 
1321  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1322  useNvGpu.setValue(useNvGpu_);
1323  }
1324 
1325  void runOnOperation() override {
1326  RewritePatternSet patterns(&getContext());
1327  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1328  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1329  return signalPassFailure();
1330 
1331  IRRewriter rewriter(&getContext());
1332  if (useNvGpu) {
1333  if (failed(
1334  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1335  return signalPassFailure();
1336  return;
1337  }
1338  (void)convertVectorToMMAOps(rewriter, getOperation());
1339  }
1340 };
1341 
1342 } // namespace
1343 
1344 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1345  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1346 }
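// Editorial usage sketch (assumed driver code, not part of this file): the
// pass can be added to a pipeline over functions, for example:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
//   if (failed(pm.run(module)))
//     llvm::errs() << "vector-to-gpu conversion failed\n";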
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map on its most minor dimensions.
Definition: AffineMap.cpp:155
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size() and as many dims as the largest dim position appearing in the expression lists.
Definition: AffineMap.cpp:312
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite.
Definition: PatternMatch.h:730
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait.
Definition: Operation.h:750
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:850
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one).
Definition: Operation.h:798
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
user_iterator user_begin()
Definition: Operation.h:870
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:749
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:125
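For illustration, a minimal sketch of building such a fragment type, assuming an available MLIRContext; the 16x16 shape, f16 element type, and "AOp" operand string are example values, not taken from this file.

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Sketch: construct the fragment type used by gpu.subgroup_mma_* ops for a
// hypothetical 16x16 f16 "A" operand. get() verifies the construction
// invariants (2-D shape, supported element type, valid operand string).
static mlir::gpu::MMAMatrixType makeAFragment(mlir::MLIRContext *ctx) {
  mlir::Type f16 = mlir::Float16Type::get(ctx);
  return mlir::gpu::MMAMatrixType::get({16, 16}, f16, "AOp");
}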
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1168
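As a hedged usage sketch: composing a simple d0 + d1 map over two index values; addIndexOffset and its operand names are illustrative and not part of this file.

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"

// Sketch: materialize `base + offset` as a composed affine.apply. Any
// affine.apply ops producing the operands are folded into the new map.
static mlir::Value addIndexOffset(mlir::OpBuilder &b, mlir::Location loc,
                                  mlir::Value base, mlir::Value offset) {
  mlir::AffineExpr d0, d1;
  mlir::bindDims(b.getContext(), d0, d1);
  auto map = mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
  mlir::SmallVector<mlir::OpFoldResult> operands = {base, offset};
  return mlir::affine::makeComposedAffineApply(b, loc, map, operands);
}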
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) to two results representing offsets within a matrix operand.
Definition: MMAUtils.cpp:169
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing offsets within the loaded matrix.
Definition: MMAUtils.cpp:234
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed load, returns the LdMatrixParams.
Definition: MMAUtils.cpp:205
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.
Definition: MMAUtils.cpp:272
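Taken together, these queries compose as in the following sketch, assuming the MMAUtils header path shown below; queryFragmentInfo is an illustrative wrapper name.

#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

// Sketch: look up the per-thread register layout of the fragment produced by
// a vector.transfer_read, bailing out if it cannot be treated as a
// warp-level matrix operand.
static mlir::FailureOr<mlir::nvgpu::FragmentElementInfo>
queryFragmentInfo(mlir::vector::TransferReadOp readOp) {
  if (!mlir::nvgpu::canLowerToWarpMatrixOperation(readOp))
    return mlir::failure();
  mlir::FailureOr<mlir::nvgpu::WarpMatrixInfo> warpInfo =
      mlir::nvgpu::getWarpMatrixInfo(readOp);
  if (mlir::failed(warpInfo))
    return mlir::failure();
  return mlir::nvgpu::getMmaSyncRegisterType(*warpInfo);
}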
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:152
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:147
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction with MMT semantics (matmul with the RHS transposed).
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. N-1].
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. the transitive defs of op).
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm.mma.sync lowering path.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
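A hedged sketch of how these entry points are typically wired together; lowerVectorToMMA is an illustrative wrapper, not an API of this file. The preparation patterns are applied greedily first, then the structural conversion runs; with useNvGpu the second step would instead be convertVectorToNVVMCompatibleMMASync.

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

// Sketch: canonicalize vector ops into MMA-friendly form, then convert them
// to gpu.subgroup_mma_* operations nested under `rootOp`. The greedy driver
// expects the rewrite scope to be isolated from above.
static mlir::LogicalResult lowerVectorToMMA(mlir::Operation *rootOp) {
  mlir::MLIRContext *ctx = rootOp->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (mlir::failed(mlir::applyPatternsGreedily(rootOp, std::move(patterns))))
    return mlir::failure();
  mlir::IRRewriter rewriter(ctx);
  return mlir::convertVectorToMMAOps(rewriter, rootOp);
}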
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. the transitive users of op).
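For context, a sketch of how the forward and backward slice utilities combine around an anchor operation; collectSlice and the always-true filter are illustrative.

#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/SetVector.h"

// Sketch: gather an anchor op together with its transitive producers and
// users. A real filter would prune the traversal instead of accepting
// every operation.
static llvm::SetVector<mlir::Operation *>
collectSlice(mlir::Operation *anchor) {
  mlir::BackwardSliceOptions backwardOptions;
  backwardOptions.filter = [](mlir::Operation *) { return true; };
  mlir::ForwardSliceOptions forwardOptions;

  llvm::SetVector<mlir::Operation *> slice;
  slice.insert(anchor);
  mlir::getBackwardSlice(anchor, &slice, backwardOptions);
  mlir::getForwardSlice(anchor, &slice, forwardOptions);
  return slice;
}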
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class rather than a raw Operation.
Definition: PatternMatch.h:314
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:29
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
Definition: MMAUtils.h:52