VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Analysis/TopologicalSortUtils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues);
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
75 
76 // Return true if the contract op can be converted to an MMA matmul.
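// For example, with (m, n, k) bound to (d0, d1, d2) below, the accepted
// row-major matmul maps are A: (d0, d1, d2) -> (d0, d2),
// B: (d0, d1, d2) -> (d2, d1) and C: (d0, d1, d2) -> (d0, d1); the nvgpu
// (mma.sync) path instead expects B in (d1, d2) form, i.e. an N x K operand.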
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [&](MapList m) {
81  return AffineMap::inferFromExprList(m, contract.getContext());
82  };
83  AffineExpr m, n, k;
84  bindDims(contract.getContext(), m, n, k);
85  auto iteratorTypes = contract.getIteratorTypes().getValue();
86  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87  vector::isParallelIterator(iteratorTypes[1]) &&
88  vector::isReductionIterator(iteratorTypes[2])))
89  return false;
90 
91  // The contract needs to represent a matmul to be able to convert to
92  // MMAMatrix matmul.
93  if (!useNvGpu &&
94  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95  return false;
96  if (useNvGpu &&
97  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98  return false;
99 
100  return true;
101 }
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
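// For example, affine_map<(d0, d1) -> (d1, d0)> (plain transpose) and
// affine_map<(d0, d1) -> (d1, 0)> (transpose + broadcast) both match, as does
// the rank-1 case affine_map<(d0) -> (d0, 0)>.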
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106  MLIRContext *ctx = permutationMap.getContext();
107  // Local OpBuilder is fine here, we just build attributes.
108  OpBuilder b(ctx);
109  auto nDim = permutationMap.getNumDims();
110  AffineExpr zero = b.getAffineConstantExpr(0);
111  if (nDim < 2) {
112  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113  AffineExpr dim0 = b.getAffineDimExpr(0);
114  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115  }
116 
117  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119  // Support both transposed and transposed+broadcasted cases.
120  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
123 
124 // Return the stride for the second-to-last dimension of |type| if it is a memref
125 // and has a constant stride.
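// For example, a memref<16x32xf16> with the default identity layout has
// strides [32, 1], so 32 is returned; a dynamic row stride or a non-unit
// innermost stride yields std::nullopt.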
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127  auto memrefType = dyn_cast<MemRefType>(type);
128  if (!memrefType)
129  return std::nullopt;
130  // If the memref is 0 or 1D the horizontal stride is 0.
131  if (memrefType.getRank() < 2)
132  return 0;
133  int64_t offset = 0;
134  SmallVector<int64_t, 2> strides;
135  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
136  strides.back() != 1)
137  return std::nullopt;
138  int64_t stride = strides[strides.size() - 2];
139  if (stride == ShapedType::kDynamic)
140  return std::nullopt;
141  return stride;
142 }
143 
144 // Return true if the transfer op can be converted to a MMA matrix load.
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147  readOp.getVectorType().getRank() != 2)
148  return false;
149  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150  return false;
151 
152  // Only allow integer types if the signedness can be inferred.
153  if (readOp.getVectorType().getElementType().isInteger(8))
154  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156  return false;
157 
158  AffineMap map = readOp.getPermutationMap();
159  MLIRContext *ctx = readOp.getContext();
160  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161  AffineExpr zero = getAffineConstantExpr(0, ctx);
162  auto broadcastInnerDim =
163  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164  return map.isMinorIdentity() || map == broadcastInnerDim ||
165  isTransposeMatrixLoadMap(map);
166 }
167 
168 // Return true if the transfer op can be converted to a MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171  // TODO: support 0-d corner case.
172  if (writeOp.getTransferRank() == 0)
173  return false;
174 
175  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176  writeOp.getVectorType().getRank() != 2)
177  return false;
178  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179  return false;
180  // TODO: Support transpose once it is added to GPU dialect ops.
181  if (!writeOp.getPermutationMap().isMinorIdentity())
182  return false;
183  return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to a MMA constant matrix op.
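/// For example, arith.constant dense<0.0> : vector<16x16xf16> qualifies, while
/// a non-splat constant or a constant whose type is not a rank-2 vector does
/// not.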
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189  auto vecType = dyn_cast<VectorType>(constantOp.getType());
190  if (!vecType || vecType.getRank() != 2)
191  return false;
192  return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197  return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203  auto transferReadOp =
204  extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205  if (!transferReadOp)
206  return false;
207  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208 }
209 
210 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211 
212 /// Return the MMA elementwise enum associated with `op` if it is supported.
213 /// Return `std::nullopt` otherwise.
214 static std::optional<gpu::MMAElementwiseOp>
215 convertElementwiseOpToMMA(Operation *op) {
216  if (isa<arith::AddFOp>(op))
217  return gpu::MMAElementwiseOp::ADDF;
218  if (isa<arith::MulFOp>(op))
219  return gpu::MMAElementwiseOp::MULF;
220  if (isa<arith::SubFOp>(op))
221  return gpu::MMAElementwiseOp::SUBF;
222  if (isa<arith::MaximumFOp>(op))
223  return gpu::MMAElementwiseOp::MAXF;
224  if (isa<arith::MinimumFOp>(op))
225  return gpu::MMAElementwiseOp::MINF;
226  if (isa<arith::DivFOp>(op))
227  return gpu::MMAElementwiseOp::DIVF;
228  if (isa<arith::AddIOp>(op))
229  return gpu::MMAElementwiseOp::ADDI;
230  if (isa<arith::MulIOp>(op))
231  return gpu::MMAElementwiseOp::MULI;
232  if (isa<arith::SubIOp>(op))
233  return gpu::MMAElementwiseOp::SUBI;
234  if (isa<arith::DivSIOp>(op))
235  return gpu::MMAElementwiseOp::DIVS;
236  if (isa<arith::DivUIOp>(op))
237  return gpu::MMAElementwiseOp::DIVU;
238  if (isa<arith::NegFOp>(op))
239  return gpu::MMAElementwiseOp::NEGATEF;
240  if (isa<arith::ExtFOp>(op))
241  return gpu::MMAElementwiseOp::EXTF;
242  return std::nullopt;
243 }
244 
245 /// Return true if the op is supported as elementwise op on MMAMatrix type.
246 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247  return convertElementwiseOpToMMA(op).has_value();
248 }
249 
250 /// Returns true if the extract strided slice op is supported with `mma.sync`
251 /// path.
252 static bool
253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254 
255  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256  nvgpu::getWarpMatrixInfo(op);
257  if (failed(warpMatrixInfo))
258  return false;
259 
260  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261  if (failed(contractOp))
262  return false;
263 
264  // Handle vector.extract_strided_slice on registers containing
265  // matrixB and matrixC operands. vector.extract_strided_slice op
266  // is not supported on registers containing matrixA operands.
267  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
268  return (cast<VectorType>(op->getResult(0).getType()) ==
269  cast<VectorType>((*contractOp).getRhs().getType()));
270  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
271  return (cast<VectorType>(op->getResult(0).getType()) ==
272  cast<VectorType>((*contractOp).getAcc().getType()));
273 
274  return false;
275 }
276 
277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278  if (isa<scf::ForOp, scf::YieldOp>(op))
279  return true;
280  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282  : transferReadSupportsMMAMatrixType(transferRead);
283  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285  : transferWriteSupportsMMAMatrixType(transferWrite);
286  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287  return useNvGpu &&
288  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289  if (auto contract = dyn_cast<vector::ContractionOp>(op))
290  return contractSupportsMMAMatrixType(contract, useNvGpu);
291  if (auto constant = dyn_cast<arith::ConstantOp>(op))
292  return constantSupportsMMAMatrixType(constant);
293  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
294  return broadcastSupportsMMAMatrixType(broadcast);
295  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
297  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
298  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300  return fpExtendSupportsMMAMatrixType(fpExtend);
301  return elementwiseSupportsMMAMatrixType(op);
302 }
303 
304 /// Return an unsorted slice handling scf.for region differently than
305 /// `getSlice`. In scf.for we only want to include as part of the slice elements
306 /// that are part of the use/def chain.
307 static SetVector<Operation *>
308 getSliceContract(Operation *op,
309  const BackwardSliceOptions &backwardSliceOptions,
310  const ForwardSliceOptions &forwardSliceOptions) {
311  SetVector<Operation *> slice;
312  slice.insert(op);
313  unsigned currentIndex = 0;
314  SetVector<Operation *> backwardSlice;
315  SetVector<Operation *> forwardSlice;
316  while (currentIndex != slice.size()) {
317  auto *currentOp = (slice)[currentIndex];
318  // Compute and insert the backwardSlice starting from currentOp.
319  backwardSlice.clear();
320  LogicalResult result =
321  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
322  assert(result.succeeded() && "expected a backward slice");
323  (void)result;
324  slice.insert_range(backwardSlice);
325 
326  // Compute and insert the forwardSlice starting from currentOp.
327  forwardSlice.clear();
328  // Special case for ForOp, we don't want to include the whole region but
329  // only the value using the region arguments.
330  // TODO: We should refine this to only care about the region arguments being
331  // converted to matrix type.
332  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
333  for (Value forOpResult : forOp.getResults())
334  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
335  for (BlockArgument &arg : forOp.getRegionIterArgs())
336  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
337  } else {
338  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
339  }
340  slice.insert_range(forwardSlice);
341  ++currentIndex;
342  }
343  return slice;
344 }
345 
346 // Analyze the slice of operations anchored on each vector.contract op within
347 // `op` to figure out if the whole slice can be converted to MMA operations.
348 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
349  bool useNvGpu) {
350  auto hasVectorDest = [](Operation *op) {
351  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
352  };
353  BackwardSliceOptions backwardSliceOptions;
354  backwardSliceOptions.filter = hasVectorDest;
355 
356  auto hasVectorSrc = [](Operation *op) {
357  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
358  };
359  ForwardSliceOptions forwardSliceOptions;
360  forwardSliceOptions.filter = hasVectorSrc;
361 
362  SetVector<Operation *> opToConvert;
363  op->walk([&](vector::ContractionOp contract) {
364  if (opToConvert.contains(contract.getOperation()))
365  return;
366  SetVector<Operation *> dependentOps =
367  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
368  // If any op in the slice cannot use the MMA matrix type, drop the whole
369  // chain. MMA matrices are stored in an opaque type so they cannot be used
370  // by all operations.
371  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
372  if (!supportsMMaMatrixType(op, useNvGpu)) {
373  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
374  return true;
375  }
376  return false;
377  }))
378  return;
379 
380  opToConvert.insert_range(dependentOps);
381  });
382  // Sort the operations so that we can convert them in topological order.
383  return topologicalSort(opToConvert);
384 }
385 
386 namespace {
387 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
388 // to MMA matmul.
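// For example, a contraction whose B operand uses the (n, k) indexing map gets
// a vector.transpose inserted on its RHS so that the rewritten op ends up in
// the canonical (m, k) x (k, n) -> (m, n) form expected by the MMA lowering.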
389 struct PrepareContractToGPUMMA
390  : public OpRewritePattern<vector::ContractionOp> {
391  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
392 
393  LogicalResult matchAndRewrite(vector::ContractionOp op,
394  PatternRewriter &rewriter) const override {
395  Location loc = op.getLoc();
396  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
397 
398  // Set up the parallel/reduction structure in right form.
399  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
400  auto infer = [&](MapList m) {
401  return AffineMap::inferFromExprList(m, op.getContext());
402  };
403  AffineExpr m, n, k;
404  bindDims(rewriter.getContext(), m, n, k);
405  static constexpr std::array<int64_t, 2> perm = {1, 0};
406  auto iteratorTypes = op.getIteratorTypes().getValue();
407  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
408  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
409  vector::isParallelIterator(iteratorTypes[1]) &&
410  vector::isReductionIterator(iteratorTypes[2])))
411  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
412  //
413  // Two outer parallel, one inner reduction (matmat flavor).
414  //
415  // This is the classical row-major matmul, nothing to do.
416  if (maps == infer({{m, k}, {k, n}, {m, n}}))
417  return rewriter.notifyMatchFailure(op, "contraction already prepared");
418  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
419  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
420  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
421  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
423  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
424  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
425  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
426  std::swap(rhs, lhs);
427  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
428  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
429  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
430  std::swap(rhs, lhs);
431  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
432  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
433  std::swap(lhs, rhs);
434  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
435  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
436  std::swap(lhs, rhs);
437  } else {
438  // TODO: llvm_unreachable ?
439  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
440  }
441  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
442  op, lhs, rhs, res,
443  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
444  op.getIteratorTypes());
445  return success();
446  }
447 };
448 
449 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
450 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
451 // respectively. We can fold the transpose operation when loading the data from
452 // Shared Memory to registers.
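// For example, a vector.transpose of a minor-identity vector.transfer_read is
// rewritten into a single transfer_read whose permutation map is composed with
// the transpose permutation ((d0, d1) -> (d1, d0) in the 2-D case), with any
// arith.ext* op re-created on top of the new read.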
453 struct CombineTransferReadOpTranspose final
454  : public OpRewritePattern<vector::TransposeOp> {
455  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
456 
457  LogicalResult matchAndRewrite(vector::TransposeOp op,
458  PatternRewriter &rewriter) const override {
459  // Look through integer extend ops.
460  Value source = op.getVector();
461  Type resultType = op.getType();
462  Operation *extOp;
463  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
464  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
465  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
466  source = extOp->getOperand(0);
467  resultType =
468  VectorType::get(cast<VectorType>(resultType).getShape(),
469  cast<VectorType>(source.getType()).getElementType());
470  }
471 
472  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
473  if (!transferReadOp)
474  return rewriter.notifyMatchFailure(op, "no transfer read");
475 
476  // TODO: support 0-d corner case.
477  if (transferReadOp.getTransferRank() == 0)
478  return rewriter.notifyMatchFailure(op, "0-D transfer read");
479 
480  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
481  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
482 
483  AffineMap permutationMap =
484  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
485  AffineMap newMap =
486  permutationMap.compose(transferReadOp.getPermutationMap());
487 
488  auto loc = op.getLoc();
489  Value result =
490  rewriter
491  .create<vector::TransferReadOp>(
492  loc, resultType, transferReadOp.getBase(),
493  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
494  transferReadOp.getPadding(), transferReadOp.getMask(),
495  transferReadOp.getInBoundsAttr())
496  .getResult();
497 
498  // Fuse through the integer extend op.
499  if (extOp) {
500  if (isa<arith::ExtSIOp>(extOp))
501  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
502  .getResult();
503  else if (isa<arith::ExtUIOp>(extOp))
504  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
505  .getResult();
506  else
507  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
508  .getResult();
509  }
510 
511  rewriter.replaceOp(op, result);
512  return success();
513  }
514 };
515 
516 } // namespace
517 
518 // MMA types have different layout based on how they are used in matmul ops.
519 // Figure the right layout to use by looking at op uses.
520 // TODO: Change the GPU dialect to abstract the layout at this level and
521 // only care about it during lowering to NVVM.
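// For example, a value feeding the LHS of a vector.contract is tagged "AOp",
// one feeding the RHS is tagged "BOp", and everything else (including the
// accumulator) defaults to "COp".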
522 static const char *inferFragType(Operation *op) {
523  // We can have arith.ext ops before reaching contract ops. See through them
524  // and other kinds of elementwise ops.
525  if (op->hasOneUse()) {
526  Operation *userOp = *op->user_begin();
527  if (userOp->hasTrait<OpTrait::Elementwise>())
528  return inferFragType(userOp);
529  }
530 
531  for (Operation *users : op->getUsers()) {
532  auto contract = dyn_cast<vector::ContractionOp>(users);
533  if (!contract)
534  continue;
535  assert(op->getNumResults() == 1);
536  if (contract.getLhs() == op->getResult(0))
537  return "AOp";
538  if (contract.getRhs() == op->getResult(0))
539  return "BOp";
540  }
541  return "COp";
542 }
543 
544 static LogicalResult
545 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
546  llvm::DenseMap<Value, Value> &valueMapping) {
547  OpBuilder::InsertionGuard g(rewriter);
548  rewriter.setInsertionPoint(op);
549 
550  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
551  assert(transferReadSupportsMMAMatrixType(op) &&
552  "expected convertible operation");
553 
554  std::optional<int64_t> stride =
555  getStaticallyKnownRowStride(op.getShapedType());
556  if (!stride.has_value()) {
557  LLVM_DEBUG(DBGS() << "no stride\n");
558  return rewriter.notifyMatchFailure(op, "no stride");
559  }
560 
561  AffineMap map = op.getPermutationMap();
562  bool isTranspose = isTransposeMatrixLoadMap(map);
563 
564  // Handle broadcast by setting the stride to 0.
565  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
566  assert(cstExpr.getValue() == 0);
567  stride = 0;
568  }
569 
570  Value mappingResult = op.getResult();
571  auto elType = op.getVectorType().getElementType();
572  const char *fragType = inferFragType(op);
573  if (op->hasOneUse()) {
574  auto *user = *op->user_begin();
575  // Infer the signedness of the mma type from the integer extend.
576  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
577  elType = IntegerType::get(
578  op.getContext(), cast<IntegerType>(elType).getWidth(),
579  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
580  : IntegerType::Unsigned);
581  mappingResult = user->getResult(0);
582  }
583  }
584  gpu::MMAMatrixType type =
585  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
586  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
587  op.getLoc(), type, op.getBase(), op.getIndices(),
588  rewriter.getIndexAttr(*stride),
589  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
590  valueMapping[mappingResult] = load;
591 
592  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
593  return success();
594 }
595 
596 static LogicalResult
597 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
598  llvm::DenseMap<Value, Value> &valueMapping) {
599  OpBuilder::InsertionGuard g(rewriter);
600  rewriter.setInsertionPoint(op);
601 
602  assert(transferWriteSupportsMMAMatrixType(op));
603  std::optional<int64_t> stride =
604  getStaticallyKnownRowStride(op.getShapedType());
605  if (!stride.has_value()) {
606  LLVM_DEBUG(DBGS() << "no stride\n");
607  return rewriter.notifyMatchFailure(op, "no stride");
608  }
609 
610  auto it = valueMapping.find(op.getVector());
611  if (it == valueMapping.end()) {
612  LLVM_DEBUG(DBGS() << "no mapping\n");
613  return rewriter.notifyMatchFailure(op, "no mapping");
614  }
615 
616  Value matrix = it->second;
617  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
618  op.getLoc(), matrix, op.getBase(), op.getIndices(),
619  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
620  (void)store;
621 
622  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
623 
624  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
625  rewriter.eraseOp(op);
626  return success();
627 }
628 
629 /// Returns the vector type which represents a matrix fragment.
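/// For example, a fragment held in 2 registers of vector<2xf16> each
/// (2 elements per register) maps to vector<2x2xf16>.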
630 static VectorType
631 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
632  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
633  regInfo.elementsPerRegister};
634  Type elType = regInfo.registerLLVMType;
635  if (auto vecType = dyn_cast<VectorType>(elType))
636  elType = vecType.getElementType();
637  return VectorType::get(shape, elType);
638 }
639 
640 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
641 static LogicalResult
642 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
643  llvm::DenseMap<Value, Value> &valueMapping) {
644  OpBuilder::InsertionGuard g(rewriter);
645  rewriter.setInsertionPoint(op);
646 
647  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
648  nvgpu::getWarpMatrixInfo(op);
649  if (failed(warpMatrixInfo)) {
650  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
651  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
652  }
653 
654  FailureOr<nvgpu::FragmentElementInfo> regInfo =
655  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
656  if (failed(regInfo)) {
657  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
658  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
659  }
660 
661  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
662  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
663  if (!dense) {
664  LLVM_DEBUG(DBGS() << "not a splat\n");
665  return rewriter.notifyMatchFailure(op, "not a splat");
666  }
667 
668  Value result = rewriter.create<arith::ConstantOp>(
669  op.getLoc(), vectorType,
670  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
671  valueMapping[op.getResult()] = result;
672  return success();
673 }
674 
675 /// Check if the loaded matrix operand requires a transpose.
676 /// Transposed Map Example:
677 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
678 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
679 /// The code below checks if the output 2D is transposed using a generalized
680 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
681 /// Returns true if m > n, false otherwise.
682 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
683  mlir::AffineMap map = op.getPermutationMap();
684 
685  if (map.getNumResults() != 2) {
686  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
687  "is not a 2d operand\n");
688  return failure();
689  }
690 
691  // Output 2D matrix dimensions in the order of d0, d1.
692  mlir::AffineExpr dM = map.getResult(0);
693  mlir::AffineExpr dN = map.getResult(1);
694 
695  // Find the position of these expressions in the input.
696  auto exprM = dyn_cast<AffineDimExpr>(dM);
697  auto exprN = dyn_cast<AffineDimExpr>(dN);
698 
699  if (!exprM || !exprN) {
700  LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
701  "expressions, then transpose cannot be determined.\n");
702  return failure();
703  }
704 
705  return exprM.getPosition() > exprN.getPosition();
706 }
707 
708 static LogicalResult
709 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
710  llvm::DenseMap<Value, Value> &valueMapping) {
711  OpBuilder::InsertionGuard g(rewriter);
712  rewriter.setInsertionPoint(op);
713  Location loc = op->getLoc();
714 
715  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
716  nvgpu::getWarpMatrixInfo(op);
717  if (failed(warpMatrixInfo)) {
718  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
719  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
720  }
721 
722  FailureOr<nvgpu::FragmentElementInfo> regInfo =
723  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
724  if (failed(regInfo)) {
725  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
726  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
727  }
728 
729  FailureOr<bool> transpose = isTransposed(op);
730  if (failed(transpose)) {
731  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
732  return rewriter.notifyMatchFailure(
733  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
734  }
735 
736  FailureOr<nvgpu::LdMatrixParams> params =
737  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
738 
739  if (failed(params)) {
740  LLVM_DEBUG(
741  DBGS()
742  << "failed to convert vector.transfer_read to ldmatrix. "
743  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
744  return rewriter.notifyMatchFailure(
745  op, "failed to convert vector.transfer_read to ldmatrix; this op "
746  "likely should not be converted to a nvgpu.ldmatrix call.");
747  }
748 
749  // Adjust the load offset.
750  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
751  FailureOr<AffineMap> offsets =
752  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
753  if (failed(offsets)) {
754  LLVM_DEBUG(DBGS() << "no offsets\n");
755  return rewriter.notifyMatchFailure(op, "no offsets");
756  }
757 
758  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
759 
760  SmallVector<Value, 4> indices;
761  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
762  indices);
763 
764  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
765  loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
766  valueMapping[op] = newOp->getResult(0);
767  return success();
768 }
769 
770 static LogicalResult
771 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
772  llvm::DenseMap<Value, Value> &valueMapping) {
773  OpBuilder::InsertionGuard g(rewriter);
774  rewriter.setInsertionPoint(op);
775 
776  Location loc = op.getLoc();
777  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
778  nvgpu::getWarpMatrixInfo(op);
779  if (failed(warpMatrixInfo))
780  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
781  FailureOr<nvgpu::FragmentElementInfo> regInfo =
782  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
783  if (failed(regInfo)) {
784  return rewriter.notifyMatchFailure(
785  op, "Failed to deduce register fragment type during "
786  "conversion to distributed non-ldmatrix compatible load");
787  }
788 
789  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
790 
791  // This is the individual element type.
792  Type loadedElType = regInfo->registerLLVMType;
793  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
794 
795  Value fill = rewriter.create<arith::ConstantOp>(
796  op.getLoc(), vectorType.getElementType(),
797  rewriter.getZeroAttr(vectorType.getElementType()));
798  Value result =
799  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
800 
801  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
802 
803  // If we are not transposing, then we can use vectorized loads. Otherwise, we
804  // must load each element individually.
805  if (!isTransposeLoad) {
806  if (!isa<VectorType>(loadedElType)) {
807  loadedElType = VectorType::get({1}, loadedElType);
808  }
809 
810  for (int i = 0; i < vectorType.getShape()[0]; i++) {
811  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
812  rewriter, op.getLoc(), *warpMatrixInfo);
813  if (failed(coords))
814  return rewriter.notifyMatchFailure(op, "no coords");
815 
816  Value logicalValueId = rewriter.create<arith::ConstantOp>(
817  loc, rewriter.getIndexType(),
818  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
819  SmallVector<Value, 4> newIndices;
820  getXferIndices<vector::TransferReadOp>(
821  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
822 
823  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
824  op.getBase(), newIndices);
825  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
826  }
827  } else {
828  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
829  loadedElType = vecType.getElementType();
830  }
831  for (int i = 0; i < vectorType.getShape()[0]; i++) {
832  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
833  innerIdx++) {
834 
835  Value logicalValueId = rewriter.create<arith::ConstantOp>(
836  loc, rewriter.getIndexType(),
837  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
838  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
839  rewriter, op.getLoc(), *warpMatrixInfo);
840  if (failed(coords))
841  return rewriter.notifyMatchFailure(op, "no coords");
842 
843  SmallVector<Value, 4> newIndices;
844  getXferIndices<vector::TransferReadOp>(
845  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
846  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
847  op.getBase(), newIndices);
848  result = rewriter.create<vector::InsertOp>(
849  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
850  }
851  }
852  }
853 
854  valueMapping[op.getResult()] = result;
855  return success();
856 }
857 
858 /// Return true if this is a shared memory memref type.
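/// For example, a memref whose memory space is #gpu.address_space<workgroup>
/// is shared memory; a memref with no memory space attribute is not.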
859 static bool isSharedMemory(MemRefType type) {
860  auto addressSpace =
861  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
862  return addressSpace &&
863  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
864 }
865 
866 /// Converts a `vector.transfer_read` operation directly to either a
867 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
868 /// used when converting to `nvgpu.mma.sync` operations.
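/// For example, a 16x16 f16 operand resident in shared (workgroup) memory with
/// an inferred 128-bit tile width takes the ldmatrix path; anything else, such
/// as a transposed read that is too narrow, falls back to per-fragment
/// vector.load / memref.load ops.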
869 static LogicalResult
870 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
871  llvm::DenseMap<Value, Value> &valueMapping) {
872  OpBuilder::InsertionGuard g(rewriter);
873  rewriter.setInsertionPoint(op);
874 
875  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
876  nvgpu::getWarpMatrixInfo(op);
877  if (failed(warpMatrixInfo))
878  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
879 
880  bool isLdMatrixCompatible =
881  isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
882  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
883 
884  VectorType vecTy = op.getVectorType();
885  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
886 
887  // When we are transposing the B operand, ldmatrix will only work if we have
888  // at least 8 rows to read and the width to read for the transpose is 128
889  // bits.
890  if (!op.getPermutationMap().isMinorIdentity() &&
891  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
892  vecTy.getDimSize(0) * bitWidth < 128))
893  isLdMatrixCompatible = false;
894 
895  if (!isLdMatrixCompatible)
896  return createNonLdMatrixLoads(rewriter, op, valueMapping);
897 
898  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
899 }
900 
901 static LogicalResult
902 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
903  llvm::DenseMap<Value, Value> &valueMapping) {
904  OpBuilder::InsertionGuard g(rewriter);
905  rewriter.setInsertionPoint(op);
906 
907  Location loc = op->getLoc();
908  auto it = valueMapping.find(op.getVector());
909  if (it == valueMapping.end())
910  return rewriter.notifyMatchFailure(op, "no mapping");
911  Value matrix = it->second;
912 
913  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
914  nvgpu::getWarpMatrixInfo(op);
915  if (failed(warpMatrixInfo))
916  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
917  FailureOr<nvgpu::FragmentElementInfo> regInfo =
918  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
919  if (failed(regInfo))
920  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
921 
922  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
923  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
924 
925  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
926  Value logicalValueId = rewriter.create<arith::ConstantOp>(
927  loc, rewriter.getIndexType(),
928  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
929  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
930  rewriter, op.getLoc(), *warpMatrixInfo);
931  if (failed(coords))
932  return rewriter.notifyMatchFailure(op, "no coords");
933 
934  Value el =
935  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
936  SmallVector<Value, 4> newIndices;
937  getXferIndices<vector::TransferWriteOp>(
938  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
939  rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
940  }
941 
942  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
943  rewriter.eraseOp(op);
944  return success();
945 }
946 
947 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
948  SmallVectorImpl<int64_t> &results) {
949  for (auto attr : arrayAttr)
950  results.push_back(cast<IntegerAttr>(attr).getInt());
951 }
952 
953 static LogicalResult
954 convertExtractStridedSlice(RewriterBase &rewriter,
955  vector::ExtractStridedSliceOp op,
956  llvm::DenseMap<Value, Value> &valueMapping) {
957  OpBuilder::InsertionGuard g(rewriter);
958  rewriter.setInsertionPoint(op);
959 
960  Location loc = op->getLoc();
961 
962  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
963  nvgpu::getWarpMatrixInfo(op);
964  if (failed(warpMatrixInfo))
965  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
966 
967  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
968  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
969  if (failed(mmaSyncFragmentInfo))
970  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
971 
972  // Find the vector.transfer_read whose result vector is being sliced.
973  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
974  if (!transferReadOp)
975  return rewriter.notifyMatchFailure(op, "no transfer read");
976 
977  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
978  if (failed(warpMatrixInfo))
979  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
980 
981  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
982  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
983  if (failed(ldFragmentInfo))
984  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
985 
986  assert(
987  (mmaSyncFragmentInfo->elementsPerRegister ==
988  ldFragmentInfo->elementsPerRegister) &&
989  "Number of elements per register should be same for load and mma.sync");
990 
991  // Create vector.extract_strided_slice op for thread-owned fragments.
992  std::array<int64_t, 2> strides = {1,
993  1}; // stride for extract slice is always 1.
994  std::array<int64_t, 2> sliceShape = {
995  mmaSyncFragmentInfo->numRegistersPerFragment,
996  mmaSyncFragmentInfo->elementsPerRegister};
997  auto it = valueMapping.find(transferReadOp);
998  if (it == valueMapping.end())
999  return rewriter.notifyMatchFailure(op, "no mapping");
1000  auto sourceVector = it->second;
1001 
1002  // Offsets and sizes at warp-level of ownership.
1003  SmallVector<int64_t> offsets;
1004  populateFromInt64AttrArray(op.getOffsets(), offsets);
1005 
1006  SmallVector<int64_t> sizes;
1007  populateFromInt64AttrArray(op.getSizes(), sizes);
1008  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1009 
1010  // Compute offset in vector registers. Note that the mma.sync vector registers
1011  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1012  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1013  std::array<int64_t, 2> sliceOffset = {0, 0};
1014 
1015  if (offsets[0] && offsets[1])
1016  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1017  if (offsets[0])
1018  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1019  else if (offsets[1])
1020  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1021 
1022  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1023  loc, sourceVector, sliceOffset, sliceShape, strides);
1024 
1025  valueMapping[op] = newOp;
1026  return success();
1027 }
1028 
1029 static LogicalResult
1030 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1031  llvm::DenseMap<Value, Value> &valueMapping) {
1032  OpBuilder::InsertionGuard g(rewriter);
1033  rewriter.setInsertionPoint(op);
1034 
1035  auto itA = valueMapping.find(op.getLhs());
1036  auto itB = valueMapping.find(op.getRhs());
1037  auto itC = valueMapping.find(op.getAcc());
1038  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1039  itC == valueMapping.end())
1040  return rewriter.notifyMatchFailure(op, "no mapping");
1041  Value opA = itA->second, opB = itB->second, opC = itC->second;
1042  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1043  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1044  /*b_transpose=*/UnitAttr());
1045  valueMapping[op.getResult()] = matmul;
1046  return success();
1047 }
1048 
1049 static LogicalResult
1050 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1051  llvm::DenseMap<Value, Value> &valueMapping) {
1052  OpBuilder::InsertionGuard g(rewriter);
1053  rewriter.setInsertionPoint(op);
1054 
1055  auto itA = valueMapping.find(op.getLhs());
1056  auto itB = valueMapping.find(op.getRhs());
1057  auto itC = valueMapping.find(op.getAcc());
1058  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1059  itC == valueMapping.end())
1060  return rewriter.notifyMatchFailure(op, "no mapping");
1061  Value opA = itA->second, opB = itB->second, opC = itC->second;
1062  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1063  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1064  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1065  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1066  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1067  valueMapping[op.getResult()] = matmul;
1068  return success();
1069 }
1070 
1071 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1072 static LogicalResult
1073 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1074  llvm::DenseMap<Value, Value> &valueMapping) {
1075  OpBuilder::InsertionGuard g(rewriter);
1076  rewriter.setInsertionPoint(op);
1077 
1078  assert(constantSupportsMMAMatrixType(op));
1079 
1080  auto splat =
1081  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1082  auto scalarConstant =
1083  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1084  const char *fragType = inferFragType(op);
1085  auto vecType = cast<VectorType>(op.getType());
1086  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1087  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1088  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1089  op.getLoc(), type, scalarConstant);
1090  valueMapping[op.getResult()] = matrix;
1091  return success();
1092 }
1093 
1094 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1095 static LogicalResult
1096 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1097  llvm::DenseMap<Value, Value> &valueMapping) {
1098  OpBuilder::InsertionGuard g(rewriter);
1099  rewriter.setInsertionPoint(op);
1100 
1101  assert(broadcastSupportsMMAMatrixType(op));
1102 
1103  const char *fragType = inferFragType(op);
1104  auto vecType = op.getResultVectorType();
1105  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1106  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1107  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1108  op.getLoc(), type, op.getSource());
1109  valueMapping[op.getResult()] = matrix;
1110  return success();
1111 }
1112 
1113 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1114 // updated and needs to be updated separately for the loop to be correct.
1115 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1116  scf::ForOp loop,
1117  ValueRange newInitArgs) {
1118  OpBuilder::InsertionGuard g(rewriter);
1119  rewriter.setInsertionPoint(loop);
1120 
1121  // Create a new loop before the existing one, with the extra operands.
1122  rewriter.setInsertionPoint(loop);
1123  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1124  llvm::append_range(operands, newInitArgs);
1125  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1126  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1127  operands);
1128  rewriter.eraseBlock(newLoop.getBody());
1129 
1130  newLoop.getRegion().getBlocks().splice(
1131  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1132  for (Value operand : newInitArgs)
1133  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1134 
1135  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1136  loop.getNumResults())))
1137  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1138 
1139  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1140  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1141  LLVM_DEBUG(DBGS() << "erase: " << loop);
1142 
1143  rewriter.eraseOp(loop);
1144  return newLoop;
1145 }
1146 
1147 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1148  llvm::DenseMap<Value, Value> &valueMapping) {
1149  OpBuilder::InsertionGuard g(rewriter);
1150  rewriter.setInsertionPoint(op);
1151 
1152  SmallVector<Value> newOperands;
1153  SmallVector<std::pair<size_t, size_t>> argMapping;
1154  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1155  auto it = valueMapping.find(operand.value());
1156  if (it == valueMapping.end()) {
1157  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1158  continue;
1159  }
1160  argMapping.push_back(std::make_pair(
1161  operand.index(), op.getInitArgs().size() + newOperands.size()));
1162  newOperands.push_back(it->second);
1163  }
1164 
1165  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1166  Block &loopBody = *newForOp.getBody();
1167  for (auto mapping : argMapping) {
1168  valueMapping[newForOp.getResult(mapping.first)] =
1169  newForOp.getResult(mapping.second);
1170  valueMapping[loopBody.getArgument(mapping.first +
1171  newForOp.getNumInductionVars())] =
1172  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1173  }
1174 
1175  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1176  return success();
1177 }
1178 
1179 static LogicalResult
1180 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1181  llvm::DenseMap<Value, Value> &valueMapping) {
1182  OpBuilder::InsertionGuard g(rewriter);
1183  rewriter.setInsertionPoint(op);
1184 
1185  auto loop = cast<scf::ForOp>(op->getParentOp());
1186  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1187  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1188  auto it = valueMapping.find(operand.value());
1189  if (it == valueMapping.end())
1190  continue;
1191  // Replace the yield of old value with the for op argument to make it easier
1192  // to remove the dead code.
1193  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1194  yieldOperands.push_back(it->second);
1195  }
1196  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1197 
1198  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1199  rewriter.eraseOp(op);
1200  return success();
1201 }
1202 
1203 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
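/// For example, an arith.addf whose operands are already mapped to MMA
/// matrices becomes a gpu.subgroup_mma_elementwise op with the ADDF kind.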
1204 static LogicalResult
1205 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1206  gpu::MMAElementwiseOp opType,
1207  llvm::DenseMap<Value, Value> &valueMapping) {
1208  OpBuilder::InsertionGuard g(rewriter);
1209  rewriter.setInsertionPoint(op);
1210 
1211  SmallVector<Value> matrixOperands;
1212  for (Value operand : op->getOperands()) {
1213  auto it = valueMapping.find(operand);
1214  if (it == valueMapping.end())
1215  return rewriter.notifyMatchFailure(op, "no mapping");
1216  matrixOperands.push_back(it->second);
1217  }
1218  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1219  if (opType == gpu::MMAElementwiseOp::EXTF) {
1220  // The floating point extension case has a different result type.
1221  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1222  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1223  vectorType.getElementType(),
1224  resultType.getOperand());
1225  }
1226 
1227  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1228  op->getLoc(), resultType, matrixOperands, opType);
1229  valueMapping[op->getResult(0)] = newOp;
1230  return success();
1231 }
1232 
1233 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1234  bool useNvGpu) {
1235  if (!useNvGpu) {
1236  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1237  patterns.getContext());
1238  return;
1239  }
1240  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1241  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1242 }
1243 
1244 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1245  Operation *rootOp) {
1246  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1247  llvm::DenseMap<Value, Value> valueMapping;
1248 
1249  auto globalRes = LogicalResult::success();
1250  for (Operation *op : ops) {
1251  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1252  // Apparently callers do not want to early exit on failure here.
1253  auto res = LogicalResult::success();
1254  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1255  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1256  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1257  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1258  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1259  res = convertContractOp(rewriter, contractOp, valueMapping);
1260  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1261  res = convertConstantOp(rewriter, constantOp, valueMapping);
1262  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1263  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1264  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1265  res = convertForOp(rewriter, forOp, valueMapping);
1266  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1267  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1268  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1269  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1270  }
1271  if (failed(res))
1272  globalRes = failure();
1273  }
1274  return globalRes;
1275 }
1276 
1277 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1278  Operation *rootOp) {
1279  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1280  llvm::DenseMap<Value, Value> valueMapping;
1281  for (Operation *op : ops) {
1282  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1283  .Case([&](vector::TransferReadOp transferReadOp) {
1284  return convertTransferReadToLoads(rewriter, transferReadOp,
1285  valueMapping);
1286  })
1287  .Case([&](vector::TransferWriteOp transferWriteOp) {
1288  return convertTransferWriteToStores(rewriter, transferWriteOp,
1289  valueMapping);
1290  })
1291  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1292  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1293  valueMapping);
1294  })
1295  .Case([&](vector::ContractionOp contractionOp) {
1296  return convertContractOpToMmaSync(rewriter, contractionOp,
1297  valueMapping);
1298  })
1299  .Case([&](scf::ForOp forOp) {
1300  return convertForOp(rewriter, forOp, valueMapping);
1301  })
1302  .Case([&](scf::YieldOp yieldOp) {
1303  return convertYieldOp(rewriter, yieldOp, valueMapping);
1304  })
1305  .Case([&](arith::ConstantOp constOp) {
1306  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1307  })
1308  .Default([&](Operation *op) {
1309  return op->emitError() << "unhandled vector to mma type: " << *op;
1310  })
1311  .failed()) {
1312  return op->emitOpError()
1313  << "failed to convert op during vector-to-nvgpu conversion";
1314  }
1315  }
1316  return success();
1317 }
1318 
1319 namespace {
1320 
1321 struct ConvertVectorToGPUPass
1322  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1323 
1324  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1325  useNvGpu.setValue(useNvGpu_);
1326  }
1327 
1328  void runOnOperation() override {
1329  RewritePatternSet patterns(&getContext());
1330  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1331  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1332  return signalPassFailure();
1333 
1334  IRRewriter rewriter(&getContext());
1335  if (useNvGpu) {
1336  if (failed(
1337  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1338  return signalPassFailure();
1339  return;
1340  }
1341  (void)convertVectorToMMAOps(rewriter, getOperation());
1342  }
1343 };
1344 
1345 } // namespace
1346 
1347 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1348  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1349 }
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a splat constant of the per-lane mma.sync fragment vector type.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:155
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
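A minimal sketch of how the two entries above fit together, assuming `ctx` and a two-result `readMap` are supplied by the caller; the helper name is illustrative and does not appear in this file:

  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"

  using namespace mlir;

  // Build the 2-D transpose permutation (d0, d1) -> (d1, d0) and compose it
  // with an existing two-result map. compose() applies its argument first, so
  // the result feeds the dims through `readMap` and then swaps the two outputs.
  AffineMap buildTransposedReadMap(MLIRContext *ctx, AffineMap readMap) {
    AffineMap transposeMap = AffineMap::getPermutationMap({1, 0}, ctx);
    return transposeMap.compose(readMap);
  }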
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
user_iterator user_begin()
Definition: Operation.h:869
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:125
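A minimal sketch of constructing such a type for a 16x16 f16 "A" operand; the operand tag strings ("AOp", "BOp", "COp") follow the gpu.subgroup_mma convention, and `ctx` plus the helper name are assumptions:

  #include "mlir/Dialect/GPU/IR/GPUDialect.h"
  #include "mlir/IR/BuiltinTypes.h"

  using namespace mlir;

  // Construct the MMA matrix type for the "A" operand of an f16 subgroup MMA.
  gpu::MMAMatrixType makeAOperandType(MLIRContext *ctx) {
    Type f16 = Float16Type::get(ctx);
    return gpu::MMAMatrixType::get(/*shape=*/{16, 16}, f16, /*operand=*/"AOp");
  }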
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1175
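A minimal sketch of offsetting an index SSA value by a constant with a composed affine.apply, assuming `b`, `loc`, and `baseIdx` come from the caller; the constant 8 and the helper name are arbitrary:

  #include "mlir/Dialect/Affine/IR/AffineOps.h"
  #include "mlir/IR/Builders.h"

  using namespace mlir;

  // Emit an affine.apply computing `baseIdx + 8`, composing with any
  // affine.apply that already defines `baseIdx`.
  Value offsetIndexByEight(OpBuilder &b, Location loc, Value baseIdx) {
    AffineExpr d0 = b.getAffineDimExpr(0);
    AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 8);
    return affine::makeComposedAffineApply(b, loc, map, {OpFoldResult(baseIdx)});
  }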
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:169
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:234
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:205
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
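A minimal sketch of chaining the helpers above to look up the per-lane register layout for an op that participates in an mma.sync computation; the helper name is illustrative, and the header path is assumed from the MMAUtils location cited in these entries:

  #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

  using namespace mlir;

  // Map an op to the register-fragment description of its warp-level operand.
  // Fails if the op does not look like an mma.sync operand.
  FailureOr<nvgpu::FragmentElementInfo> queryFragmentInfo(Operation *op) {
    FailureOr<nvgpu::WarpMatrixInfo> warpInfo = nvgpu::getWarpMatrixInfo(op);
    if (failed(warpInfo))
      return failure();
    return nvgpu::getMmaSyncRegisterType(*warpInfo);
  }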
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:272
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:152
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:147
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:23
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
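A minimal sketch of a driver that combines this entry with applyPatternsGreedily and convertVectorToMMAOps (both listed below), roughly mirroring what the conversion pass does; `funcOp` and the helper name are assumptions, and error handling is kept minimal:

  #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Canonicalize vector ops into MMA-friendly form, then rewrite them to
  // gpu.subgroup_mma operations.
  LogicalResult lowerVectorToWMMA(func::FuncOp funcOp) {
    MLIRContext *ctx = funcOp.getContext();

    RewritePatternSet patterns(ctx);
    populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
    if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
      return failure();

    IRRewriter rewriter(ctx);
    return convertVectorToMMAOps(rewriter, funcOp.getOperation());
  }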
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
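A minimal sketch using the entry above together with BackwardSliceOptions::filter (listed further below as TransitiveFilter filter) to keep the slice within a single block; the helper name is illustrative:

  #include "mlir/Analysis/SliceAnalysis.h"

  using namespace mlir;

  // Collect the transitive producers of `op`, but only follow values defined
  // in the same block. The LogicalResult returned by getBackwardSlice is
  // deliberately ignored here.
  SetVector<Operation *> backwardSliceWithinBlock(Operation *op) {
    BackwardSliceOptions options;
    options.filter = [block = op->getBlock()](Operation *candidate) {
      return candidate->getBlock() == block;
    };
    SetVector<Operation *> slice;
    (void)getBackwardSlice(op, &slice, options);
    return slice;
  }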
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
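A minimal sketch of plugging this factory into a pass pipeline; nesting it under func.func is an assumption that matches the common case of function-local vector code:

  #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Pass/PassManager.h"

  using namespace mlir;

  // Run the vector-to-GPU conversion on every func.func in the module.
  void buildVectorToGPUPipeline(PassManager &pm, bool useNvGpu) {
    pm.addNestedPass<func::FuncOp>(createConvertVectorToGPUPass(useNvGpu));
  }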
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:29
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52