VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Analysis/TopologicalSortUtils.h"
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinOps.h"
32 #include "mlir/IR/Region.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues);
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
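
// For illustration (hypothetical values, not from a test): given a transfer op
// with indices [%i, %j], a minor-identity permutation map, dimValues = {%laneId},
// and an offsetMap such as (lane) -> (lane floordiv 4, lane mod 4), the function
// produces
//   indices[0] = affine.apply(%i + %laneId floordiv 4)
//   indices[1] = affine.apply(%j + %laneId mod 4)
// i.e. each original index is shifted by the per-lane offset associated with the
// corresponding result of the permutation map.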
75 
76 // Return true if the contract op can be converted to an MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [&](MapList m) {
81  return AffineMap::inferFromExprList(m, contract.getContext());
82  };
83  AffineExpr m, n, k;
84  bindDims(contract.getContext(), m, n, k);
85  auto iteratorTypes = contract.getIteratorTypes().getValue();
86  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87  vector::isParallelIterator(iteratorTypes[1]) &&
88  vector::isReductionIterator(iteratorTypes[2])))
89  return false;
90 
91  // The contract needs to represent a matmul to be able to convert to
92  // MMAMatrix matmul.
93  if (!useNvGpu &&
94  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95  return false;
96  if (useNvGpu &&
97  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98  return false;
99 
100  return true;
101 }
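
// For illustration, the indexing maps accepted above, written as they would
// appear in IR (a sketch; the dimension names are just d0 = m, d1 = n, d2 = k):
//
//   WMMA path (useNvGpu == false), row-major A[m, k] * B[k, n] += C[m, n]:
//     affine_map<(m, n, k) -> (m, k)>
//     affine_map<(m, n, k) -> (k, n)>
//     affine_map<(m, n, k) -> (m, n)>
//
//   mma.sync path (useNvGpu == true), B is expected in [n, k] (transposed) form:
//     affine_map<(m, n, k) -> (m, k)>
//     affine_map<(m, n, k) -> (n, k)>
//     affine_map<(m, n, k) -> (m, n)>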
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106  MLIRContext *ctx = permutationMap.getContext();
107  // Local OpBuilder is fine here, we just build attributes.
108  OpBuilder b(ctx);
109  auto nDim = permutationMap.getNumDims();
110  AffineExpr zero = b.getAffineConstantExpr(0);
111  if (nDim < 2) {
112  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113  AffineExpr dim0 = b.getAffineDimExpr(0);
114  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115  }
116 
117  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119  // Support both transposed and transposed+broadcasted cases.
120  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
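
// For illustration, maps recognized by isTransposeMatrixLoadMap (a sketch):
//   affine_map<(d0, d1) -> (d1, d0)>        // plain 2-D transpose
//   affine_map<(d0, d1, d2) -> (d2, d1)>    // transpose of the two innermost dims
//   affine_map<(d0, d1) -> (d1, 0)>         // transposed + broadcast
//   affine_map<(d0) -> (d0, 0)>             // 1-D transposed + broadcast case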
123 
124 // Return the stride for the second-to-last dimension of |type| if it is a memref
125 // and has a constant stride.
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127  auto memrefType = dyn_cast<MemRefType>(type);
128  if (!memrefType)
129  return std::nullopt;
130  // If the memref is 0 or 1D the horizontal stride is 0.
131  if (memrefType.getRank() < 2)
132  return 0;
133  int64_t offset = 0;
134  SmallVector<int64_t, 2> strides;
135  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
136  strides.back() != 1)
137  return std::nullopt;
138  int64_t stride = strides[strides.size() - 2];
139  if (stride == ShapedType::kDynamic)
140  return std::nullopt;
141  return stride;
142 }
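
// For illustration (hypothetical memref types):
//   memref<16x32xf16>                             -> 32
//   memref<16x32xf16, strided<[64, 1]>>           -> 64
//   memref<16x32xf16, strided<[?, 1], offset: ?>> -> std::nullopt (dynamic stride)
//   memref<16x32xf16, strided<[64, 2]>>           -> std::nullopt (innermost != 1)
//   memref<32xf16>                                -> 0 (rank < 2)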
143 
144 // Return true if the transfer op can be converted to an MMA matrix load.
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147  readOp.getVectorType().getRank() != 2)
148  return false;
149  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150  return false;
151 
152  // Only allow integer types if the signedness can be inferred.
153  if (readOp.getVectorType().getElementType().isInteger(8))
154  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156  return false;
157 
158  AffineMap map = readOp.getPermutationMap();
159  MLIRContext *ctx = readOp.getContext();
160  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161  AffineExpr zero = getAffineConstantExpr(0, ctx);
162  auto broadcastInnerDim =
163  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164  return map.isMinorIdentity() || map == broadcastInnerDim ||
165  isTransposeMatrixLoadMap(map);
166 }
167 
168 // Return true if the transfer op can be converted to an MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171  // TODO: support 0-d corner case.
172  if (writeOp.getTransferRank() == 0)
173  return false;
174 
175  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176  writeOp.getVectorType().getRank() != 2)
177  return false;
178  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179  return false;
180  // TODO: Support transpose once it is added to GPU dialect ops.
181  if (!writeOp.getPermutationMap().isMinorIdentity())
182  return false;
183  return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to an MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189  auto vecType = dyn_cast<VectorType>(constantOp.getType());
190  if (!vecType || vecType.getRank() != 2)
191  return false;
192  return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197  return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203  auto transferReadOp =
204  extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205  if (!transferReadOp)
206  return false;
207  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208 }
209 
210 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211 
212 /// Return the MMA elementwise enum associated with `op` if it is supported.
213 /// Return `std::nullopt` otherwise.
214 static std::optional<gpu::MMAElementwiseOp>
215 convertElementwiseOpToMMA(Operation *op) {
216  if (isa<arith::AddFOp>(op))
217  return gpu::MMAElementwiseOp::ADDF;
218  if (isa<arith::MulFOp>(op))
219  return gpu::MMAElementwiseOp::MULF;
220  if (isa<arith::SubFOp>(op))
221  return gpu::MMAElementwiseOp::SUBF;
222  if (isa<arith::MaximumFOp>(op))
223  return gpu::MMAElementwiseOp::MAXF;
224  if (isa<arith::MinimumFOp>(op))
225  return gpu::MMAElementwiseOp::MINF;
226  if (isa<arith::DivFOp>(op))
227  return gpu::MMAElementwiseOp::DIVF;
228  if (isa<arith::AddIOp>(op))
229  return gpu::MMAElementwiseOp::ADDI;
230  if (isa<arith::MulIOp>(op))
231  return gpu::MMAElementwiseOp::MULI;
232  if (isa<arith::SubIOp>(op))
233  return gpu::MMAElementwiseOp::SUBI;
234  if (isa<arith::DivSIOp>(op))
234  if (isa<arith::DivSIOp>(op))
235  return gpu::MMAElementwiseOp::DIVS;
236  if (isa<arith::DivUIOp>(op))
237  return gpu::MMAElementwiseOp::DIVU;
238  if (isa<arith::NegFOp>(op))
239  return gpu::MMAElementwiseOp::NEGATEF;
240  if (isa<arith::ExtFOp>(op))
241  return gpu::MMAElementwiseOp::EXTF;
242  return std::nullopt;
243 }
244 
245 /// Return true if the op is supported as an elementwise op on the MMAMatrix type.
246 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247  return convertElementwiseOpToMMA(op).has_value();
248 }
249 
250 /// Returns true if the extract strided slice op is supported with `mma.sync`
251 /// path.
252 static bool
253 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254 
255  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256  nvgpu::getWarpMatrixInfo(op);
257  if (failed(warpMatrixInfo))
258  return false;
259 
260  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261  if (failed(contractOp))
262  return false;
263 
264  // Handle vector.extract_strided_slice on registers containing
265  // matrixB and matrixC operands. vector.extract_strided_slice op
266  // is not supported on registers containing matrixA operands.
267  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
268  return (cast<VectorType>(op->getResult(0).getType()) ==
269  cast<VectorType>((*contractOp).getRhs().getType()));
270  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
271  return (cast<VectorType>(op->getResult(0).getType()) ==
272  cast<VectorType>((*contractOp).getAcc().getType()));
273 
274  return false;
275 }
276 
277 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278  if (isa<scf::ForOp, scf::YieldOp>(op))
279  return true;
280  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282  : transferReadSupportsMMAMatrixType(transferRead);
283  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285  : transferWriteSupportsMMAMatrixType(transferWrite);
286  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287  return useNvGpu &&
288  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289  if (auto contract = dyn_cast<vector::ContractionOp>(op))
290  return contractSupportsMMAMatrixType(contract, useNvGpu);
291  if (auto constant = dyn_cast<arith::ConstantOp>(op))
292  return constantSupportsMMAMatrixType(constant);
293  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
294  return broadcastSupportsMMAMatrixType(broadcast);
295  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
297  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
298  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300  return fpExtendSupportsMMAMatrixType(fpExtend);
301  return elementwiseSupportsMMAMatrixType(op);
302 }
303 
304 /// Return an unsorted slice, handling the scf.for region differently than
305 /// `getSlice`: in scf.for we only want to include as part of the slice the
306 /// elements that are part of the use/def chain.
307 static SetVector<Operation *>
308 getSliceContract(Operation *op,
309  const BackwardSliceOptions &backwardSliceOptions,
310  const ForwardSliceOptions &forwardSliceOptions) {
311  SetVector<Operation *> slice;
312  slice.insert(op);
313  unsigned currentIndex = 0;
314  SetVector<Operation *> backwardSlice;
315  SetVector<Operation *> forwardSlice;
316  while (currentIndex != slice.size()) {
317  auto *currentOp = (slice)[currentIndex];
318  // Compute and insert the backwardSlice starting from currentOp.
319  backwardSlice.clear();
320  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
321  slice.insert(backwardSlice.begin(), backwardSlice.end());
322 
323  // Compute and insert the forwardSlice starting from currentOp.
324  forwardSlice.clear();
325  // Special case for ForOp: we don't want to include the whole region, but
326  // only the values using the region arguments.
327  // TODO: We should refine this to only care about the region arguments being
328  // converted to matrix type.
329  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
330  for (Value forOpResult : forOp.getResults())
331  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
332  for (BlockArgument &arg : forOp.getRegionIterArgs())
333  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
334  } else {
335  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
336  }
337  slice.insert(forwardSlice.begin(), forwardSlice.end());
338  ++currentIndex;
339  }
340  return slice;
341 }
342 
343 // Analyze the slice of operations anchored on each vector.contract op to figure
344 // out if the whole slice can be converted to MMA operations.
345 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
346  bool useNvGpu) {
347  auto hasVectorDest = [](Operation *op) {
348  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
349  };
350  BackwardSliceOptions backwardSliceOptions;
351  backwardSliceOptions.filter = hasVectorDest;
352 
353  auto hasVectorSrc = [](Operation *op) {
354  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
355  };
356  ForwardSliceOptions forwardSliceOptions;
357  forwardSliceOptions.filter = hasVectorSrc;
358 
359  SetVector<Operation *> opToConvert;
360  op->walk([&](vector::ContractionOp contract) {
361  if (opToConvert.contains(contract.getOperation()))
362  return;
363  SetVector<Operation *> dependentOps =
364  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
365  // If any op in the slice cannot use the MMA matrix type, drop the whole
366  // chain. MMA matrices are stored in an opaque type, so they cannot be used
367  // by all operations.
368  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369  if (!supportsMMaMatrixType(op, useNvGpu)) {
370  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
371  return true;
372  }
373  return false;
374  }))
375  return;
376 
377  opToConvert.insert(dependentOps.begin(), dependentOps.end());
378  });
379  // Sort the operations so that we can convert them in topological order.
380  return topologicalSort(opToConvert);
381 }
382 
383 namespace {
384 // Transform a contract into (m, k)x(k, n)x(m, n) form so that it can be
385 // converted to an MMA matmul.
386 struct PrepareContractToGPUMMA
387  : public OpRewritePattern<vector::ContractionOp> {
388  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
389 
390  LogicalResult matchAndRewrite(vector::ContractionOp op,
391  PatternRewriter &rewriter) const override {
392  Location loc = op.getLoc();
393  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
394 
395  // Set up the parallel/reduction structure in right form.
396  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
397  auto infer = [&](MapList m) {
398  return AffineMap::inferFromExprList(m, op.getContext());
399  };
400  AffineExpr m, n, k;
401  bindDims(rewriter.getContext(), m, n, k);
402  static constexpr std::array<int64_t, 2> perm = {1, 0};
403  auto iteratorTypes = op.getIteratorTypes().getValue();
404  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
405  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
406  vector::isParallelIterator(iteratorTypes[1]) &&
407  vector::isReductionIterator(iteratorTypes[2])))
408  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
409  //
410  // Two outer parallel, one inner reduction (matmat flavor).
411  //
412  // This is the classical row-major matmul, nothing to do.
413  if (maps == infer({{m, k}, {k, n}, {m, n}}))
414  return rewriter.notifyMatchFailure(op, "contraction already prepared");
415  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
416  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
417  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
418  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
419  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
420  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
421  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
423  std::swap(rhs, lhs);
424  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
425  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
426  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
427  std::swap(rhs, lhs);
428  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
429  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
430  std::swap(lhs, rhs);
431  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
432  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
433  std::swap(lhs, rhs);
434  } else {
435  // TODO: llvm_unreachable ?
436  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
437  }
438  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
439  op, lhs, rhs, res,
440  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
441  op.getIteratorTypes());
442  return success();
443  }
444 };
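
// For illustration, a simplified before/after for the {{m, k}, {n, k}, {m, n}}
// case handled above (hypothetical types; attributes abbreviated):
//
//   %d = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                                          affine_map<(m, n, k) -> (n, k)>,
//                                          affine_map<(m, n, k) -> (m, n)>], ...}
//          %a, %b, %c : vector<16x8xf16>, vector<16x8xf16> into vector<16x16xf16>
//
// becomes the canonical row-major form with a transposed RHS:
//
//   %bt = vector.transpose %b, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
//   %d  = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                                           affine_map<(m, n, k) -> (k, n)>,
//                                           affine_map<(m, n, k) -> (m, n)>], ...}
//           %a, %bt, %c : vector<16x8xf16>, vector<8x16xf16> into vector<16x16xf16>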
445 
446 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
447 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
448 // respectively. We can fold the transpose operation when loading the data from
449 // Shared Memory to registers.
450 struct CombineTransferReadOpTranspose final
451  : public OpRewritePattern<vector::TransposeOp> {
452  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
453 
454  LogicalResult matchAndRewrite(vector::TransposeOp op,
455  PatternRewriter &rewriter) const override {
456  // Look through integer and floating-point extend ops.
457  Value source = op.getVector();
458  Type resultType = op.getType();
459  Operation *extOp;
460  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
461  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
462  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
463  source = extOp->getOperand(0);
464  resultType =
465  VectorType::get(cast<VectorType>(resultType).getShape(),
466  cast<VectorType>(source.getType()).getElementType());
467  }
468 
469  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
470  if (!transferReadOp)
471  return rewriter.notifyMatchFailure(op, "no transfer read");
472 
473  // TODO: support 0-d corner case.
474  if (transferReadOp.getTransferRank() == 0)
475  return rewriter.notifyMatchFailure(op, "0-D transfer read");
476 
477  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
478  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
479 
480  AffineMap permutationMap =
481  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
482  AffineMap newMap =
483  permutationMap.compose(transferReadOp.getPermutationMap());
484 
485  auto loc = op.getLoc();
486  Value result =
487  rewriter
488  .create<vector::TransferReadOp>(
489  loc, resultType, transferReadOp.getSource(),
490  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
491  transferReadOp.getPadding(), transferReadOp.getMask(),
492  transferReadOp.getInBoundsAttr())
493  .getResult();
494 
495  // Fuse through the extend op.
496  if (extOp) {
497  if (isa<arith::ExtSIOp>(extOp))
498  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
499  .getResult();
500  else if (isa<arith::ExtUIOp>(extOp))
501  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
502  .getResult();
503  else
504  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
505  .getResult();
506  }
507 
508  rewriter.replaceOp(op, result);
509  return success();
510  }
511 };
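
// For illustration (a sketch with hypothetical types), the pattern above folds
//
//   %v = vector.transfer_read %mem[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//          : memref<32x32xf16>, vector<16x8xf16>
//   %t = vector.transpose %v, [1, 0] : vector<16x8xf16> to vector<8x16xf16>
//
// into a single read whose permutation map composes in the transpose:
//
//   %t = vector.transfer_read %mem[%i, %j], %pad
//          {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//          : memref<32x32xf16>, vector<8x16xf16>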
512 
513 } // namespace
514 
515 // MMA types have different layouts based on how they are used in matmul ops.
516 // Figure out the right layout to use by looking at op uses.
517 // TODO: Change the GPU dialect to abstract the layout at this level and
518 // only care about it during lowering to NVVM.
519 static const char *inferFragType(Operation *op) {
520  // We can have arith.ext ops before reaching contract ops. See through them
521  // and other kinds of elementwise ops.
522  if (op->hasOneUse()) {
523  Operation *userOp = *op->user_begin();
524  if (userOp->hasTrait<OpTrait::Elementwise>())
525  return inferFragType(userOp);
526  }
527 
528  for (Operation *users : op->getUsers()) {
529  auto contract = dyn_cast<vector::ContractionOp>(users);
530  if (!contract)
531  continue;
532  assert(op->getNumResults() == 1);
533  if (contract.getLhs() == op->getResult(0))
534  return "AOp";
535  if (contract.getRhs() == op->getResult(0))
536  return "BOp";
537  }
538  return "COp";
539 }
540 
541 static LogicalResult
542 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
543  llvm::DenseMap<Value, Value> &valueMapping) {
544  OpBuilder::InsertionGuard g(rewriter);
545  rewriter.setInsertionPoint(op);
546 
547  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
548  assert(transferReadSupportsMMAMatrixType(op) &&
549  "expected convertible operation");
550 
551  std::optional<int64_t> stride =
552  getStaticallyKnownRowStride(op.getShapedType());
553  if (!stride.has_value()) {
554  LLVM_DEBUG(DBGS() << "no stride\n");
555  return rewriter.notifyMatchFailure(op, "no stride");
556  }
557 
558  AffineMap map = op.getPermutationMap();
559  bool isTranspose = isTransposeMatrixLoadMap(map);
560 
561  // Handle broadcast by setting the stride to 0.
562  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
563  assert(cstExpr.getValue() == 0);
564  stride = 0;
565  }
566 
567  Value mappingResult = op.getResult();
568  auto elType = op.getVectorType().getElementType();
569  const char *fragType = inferFragType(op);
570  if (op->hasOneUse()) {
571  auto *user = *op->user_begin();
572  // Infer the signedness of the mma type from the integer extend.
573  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
574  elType = IntegerType::get(
575  op.getContext(), cast<IntegerType>(elType).getWidth(),
576  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
577  : IntegerType::Unsigned);
578  mappingResult = user->getResult(0);
579  }
580  }
581  gpu::MMAMatrixType type =
582  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
583  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
584  op.getLoc(), type, op.getSource(), op.getIndices(),
585  rewriter.getIndexAttr(*stride),
586  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
587  valueMapping[mappingResult] = load;
588 
589  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
590  return success();
591 }
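
// For illustration (hypothetical types), a supported read feeding the LHS of a
// contraction, such as
//
//   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//          : memref<64x64xf16>, vector<16x16xf16>
//
// is rewritten to a warp-level matrix load with the row stride as leadDimension:
//
//   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 64 : index}
//          : memref<64x64xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">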
592 
593 static LogicalResult
594 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
595  llvm::DenseMap<Value, Value> &valueMapping) {
596  OpBuilder::InsertionGuard g(rewriter);
597  rewriter.setInsertionPoint(op);
598 
599  assert(transferWriteSupportsMMAMatrixType(op));
600  std::optional<int64_t> stride =
601  getStaticallyKnownRowStride(op.getShapedType());
602  if (!stride.has_value()) {
603  LLVM_DEBUG(DBGS() << "no stride\n");
604  return rewriter.notifyMatchFailure(op, "no stride");
605  }
606 
607  auto it = valueMapping.find(op.getVector());
608  if (it == valueMapping.end()) {
609  LLVM_DEBUG(DBGS() << "no mapping\n");
610  return rewriter.notifyMatchFailure(op, "no mapping");
611  }
612 
613  Value matrix = it->second;
614  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
615  op.getLoc(), matrix, op.getSource(), op.getIndices(),
616  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
617  (void)store;
618 
619  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
620 
621  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
622  rewriter.eraseOp(op);
623  return success();
624 }
625 
626 /// Returns the vector type which represents a matrix fragment.
627 static VectorType
628 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
629  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
630  regInfo.elementsPerRegister};
631  Type elType = regInfo.registerLLVMType;
632  if (auto vecType = dyn_cast<VectorType>(elType))
633  elType = vecType.getElementType();
634  return VectorType::get(shape, elType);
635 }
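
// For illustration (hypothetical register info): a fragment described by
//   regInfo.numRegistersPerFragment = 4
//   regInfo.elementsPerRegister     = 2
//   regInfo.registerLLVMType        = vector<2xf16>
// maps to vector<4x2xf16>; only the element type of a vector register type is
// kept.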
636 
637 /// Convert a 2D splat ConstantOp to a per-thread constant vector for the
638 /// mma.sync path.
638 static LogicalResult
639 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
640  llvm::DenseMap<Value, Value> &valueMapping) {
641  OpBuilder::InsertionGuard g(rewriter);
642  rewriter.setInsertionPoint(op);
643 
644  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
645  nvgpu::getWarpMatrixInfo(op);
646  if (failed(warpMatrixInfo)) {
647  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
648  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
649  }
650 
651  FailureOr<nvgpu::FragmentElementInfo> regInfo =
652  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
653  if (failed(regInfo)) {
654  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
655  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
656  }
657 
658  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
659  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
660  if (!dense) {
661  LLVM_DEBUG(DBGS() << "not a splat\n");
662  return rewriter.notifyMatchFailure(op, "not a splat");
663  }
664 
665  Value result = rewriter.create<arith::ConstantOp>(
666  op.getLoc(), vectorType,
667  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
668  valueMapping[op.getResult()] = result;
669  return success();
670 }
671 
672 /// Check if the loaded matrix operand requires a transposed load.
673 /// Transposed map examples:
674 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
675 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
676 /// The code below checks if the output 2D map is transposed using a
677 /// generalized version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
678 /// Returns true if m > n, false otherwise.
679 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
680  mlir::AffineMap map = op.getPermutationMap();
681 
682  if (map.getNumResults() != 2) {
683  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
684  "is not a 2d operand\n");
685  return failure();
686  }
687 
688  // Output 2D matrix dimensions in the order of d0, d1.
689  mlir::AffineExpr dM = map.getResult(0);
690  mlir::AffineExpr dN = map.getResult(1);
691 
692  // Find the position of these expressions in the input.
693  auto exprM = dyn_cast<AffineDimExpr>(dM);
694  auto exprN = dyn_cast<AffineDimExpr>(dN);
695 
696  if (!exprM || !exprN) {
697  LLVM_DEBUG(DBGS() << "Failed because the expressions are not affine dim "
698  "expressions, so the transpose cannot be determined.\n");
699  return failure();
700  }
701 
702  return exprM.getPosition() > exprN.getPosition();
703 }
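
// For illustration:
//   affine_map<(d0, d1) -> (d0, d1)>         -> false (not transposed)
//   affine_map<(d0, d1) -> (d1, d0)>         -> true
//   affine_map<(d0, d1, d2, d3) -> (d3, d2)> -> true
//   affine_map<(d0, d1) -> (d0, 0)>          -> failure() (second result not a dim)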
704 
705 static LogicalResult
706 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
707  llvm::DenseMap<Value, Value> &valueMapping) {
708  OpBuilder::InsertionGuard g(rewriter);
709  rewriter.setInsertionPoint(op);
710  Location loc = op->getLoc();
711 
712  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
713  nvgpu::getWarpMatrixInfo(op);
714  if (failed(warpMatrixInfo)) {
715  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
716  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
717  }
718 
719  FailureOr<nvgpu::FragmentElementInfo> regInfo =
720  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
721  if (failed(regInfo)) {
722  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
723  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
724  }
725 
726  FailureOr<bool> transpose = isTransposed(op);
727  if (failed(transpose)) {
728  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
729  return rewriter.notifyMatchFailure(
730  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
731  }
732 
733  FailureOr<nvgpu::LdMatrixParams> params =
734  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
735 
736  if (failed(params)) {
737  LLVM_DEBUG(
738  DBGS()
739  << "failed to convert vector.transfer_read to ldmatrix. "
740  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
741  return rewriter.notifyMatchFailure(
742  op, "failed to convert vector.transfer_read to ldmatrix; this op "
743  "likely should not be converted to a nvgpu.ldmatrix call.");
744  }
745 
746  // Adjust the load offset.
747  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
748  FailureOr<AffineMap> offsets =
749  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
750  if (failed(offsets)) {
751  LLVM_DEBUG(DBGS() << "no offsets\n");
752  return rewriter.notifyMatchFailure(op, "no offsets");
753  }
754 
755  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
756 
757  SmallVector<Value, 4> indices;
758  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
759  indices);
760 
761  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
762  loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
763  valueMapping[op] = newOp->getResult(0);
764  return success();
765 }
766 
767 static LogicalResult
768 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
769  llvm::DenseMap<Value, Value> &valueMapping) {
770  OpBuilder::InsertionGuard g(rewriter);
771  rewriter.setInsertionPoint(op);
772 
773  Location loc = op.getLoc();
774  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
775  nvgpu::getWarpMatrixInfo(op);
776  if (failed(warpMatrixInfo))
777  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
778  FailureOr<nvgpu::FragmentElementInfo> regInfo =
779  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
780  if (failed(regInfo)) {
781  return rewriter.notifyMatchFailure(
782  op, "Failed to deduce register fragment type during "
783  "conversion to distributed non-ldmatrix compatible load");
784  }
785 
786  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
787  SmallVector<Value, 4> elements;
788 
789  // This is the individual element type.
790  Type loadedElType = regInfo->registerLLVMType;
791  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
792 
793  Value fill = rewriter.create<arith::ConstantOp>(
794  op.getLoc(), vectorType.getElementType(),
795  rewriter.getZeroAttr(vectorType.getElementType()));
796  Value result =
797  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
798 
799  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
800 
801  // If we are not transposing, then we can use vectorized loads. Otherwise, we
802  // must load each element individually.
803  if (!isTransposeLoad) {
804  if (!isa<VectorType>(loadedElType)) {
805  loadedElType = VectorType::get({1}, loadedElType);
806  }
807 
808  for (int i = 0; i < vectorType.getShape()[0]; i++) {
809  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
810  rewriter, op.getLoc(), *warpMatrixInfo);
811  if (failed(coords))
812  return rewriter.notifyMatchFailure(op, "no coords");
813 
814  Value logicalValueId = rewriter.create<arith::ConstantOp>(
815  loc, rewriter.getIndexType(),
816  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
817  SmallVector<Value, 4> newIndices;
818  getXferIndices<vector::TransferReadOp>(
819  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
820 
821  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
822  op.getSource(), newIndices);
823  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
824  }
825  } else {
826  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
827  loadedElType = vecType.getElementType();
828  }
829  for (int i = 0; i < vectorType.getShape()[0]; i++) {
830  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
831  innerIdx++) {
832 
833  Value logicalValueId = rewriter.create<arith::ConstantOp>(
834  loc, rewriter.getIndexType(),
835  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
836  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
837  rewriter, op.getLoc(), *warpMatrixInfo);
838  if (failed(coords))
839  return rewriter.notifyMatchFailure(op, "no coords");
840 
841  SmallVector<Value, 4> newIndices;
842  getXferIndices<vector::TransferReadOp>(
843  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
844  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
845  op.getSource(), newIndices);
846  result = rewriter.create<vector::InsertOp>(
847  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
848  }
849  }
850  }
851 
852  valueMapping[op.getResult()] = result;
853  return success();
854 }
855 
856 /// Return true if this is a shared memory memref type.
857 static bool isSharedMemory(MemRefType type) {
858  auto addressSpace =
859  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
860  return addressSpace &&
861  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
862 }
863 
864 /// Converts a `vector.transfer_read` operation directly to either a
865 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
866 /// used when converting to `nvgpu.mma.sync` operations.
867 static LogicalResult
868 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
869  llvm::DenseMap<Value, Value> &valueMapping) {
870  OpBuilder::InsertionGuard g(rewriter);
871  rewriter.setInsertionPoint(op);
872 
873  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
874  nvgpu::getWarpMatrixInfo(op);
875  if (failed(warpMatrixInfo))
876  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
877 
878  bool isLdMatrixCompatible =
879  isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
880  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
881 
882  VectorType vecTy = op.getVectorType();
883  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
884 
885  // When we are transposing the B operand, ldmatrix will only work if we have
886  // at least 8 rows to read and the width to read for the transpose is 128
887  // bits.
888  if (!op.getPermutationMap().isMinorIdentity() &&
889  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
890  vecTy.getDimSize(0) * bitWidth < 128))
891  isLdMatrixCompatible = false;
892 
893  if (!isLdMatrixCompatible)
894  return createNonLdMatrixLoads(rewriter, op, valueMapping);
895 
896  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
897 }
898 
899 static LogicalResult
900 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
901  llvm::DenseMap<Value, Value> &valueMapping) {
902  OpBuilder::InsertionGuard g(rewriter);
903  rewriter.setInsertionPoint(op);
904 
905  Location loc = op->getLoc();
906  auto it = valueMapping.find(op.getVector());
907  if (it == valueMapping.end())
908  return rewriter.notifyMatchFailure(op, "no mapping");
909  Value matrix = it->second;
910 
911  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
912  nvgpu::getWarpMatrixInfo(op);
913  if (failed(warpMatrixInfo))
914  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
915  FailureOr<nvgpu::FragmentElementInfo> regInfo =
916  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
917  if (failed(regInfo))
918  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
919 
920  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
921  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
922 
923  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
924  Value logicalValueId = rewriter.create<arith::ConstantOp>(
925  loc, rewriter.getIndexType(),
926  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
927  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
928  rewriter, op.getLoc(), *warpMatrixInfo);
929  if (failed(coords))
930  return rewriter.notifyMatchFailure(op, "no coords");
931 
932  Value el =
933  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
934  SmallVector<Value, 4> newIndices;
935  getXferIndices<vector::TransferWriteOp>(
936  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
937  rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
938  }
939 
940  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
941  rewriter.eraseOp(op);
942  return success();
943 }
944 
945 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
946  SmallVectorImpl<int64_t> &results) {
947  for (auto attr : arrayAttr)
948  results.push_back(cast<IntegerAttr>(attr).getInt());
949 }
950 
951 static LogicalResult
952 convertExtractStridedSlice(RewriterBase &rewriter,
953  vector::ExtractStridedSliceOp op,
954  llvm::DenseMap<Value, Value> &valueMapping) {
955  OpBuilder::InsertionGuard g(rewriter);
956  rewriter.setInsertionPoint(op);
957 
958  Location loc = op->getLoc();
959 
960  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
961  nvgpu::getWarpMatrixInfo(op);
962  if (failed(warpMatrixInfo))
963  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
964 
965  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
966  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
967  if (failed(mmaSyncFragmentInfo))
968  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
969 
970  // Find the vector.transfer_read whose result vector is being sliced.
971  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
972  if (!transferReadOp)
973  return rewriter.notifyMatchFailure(op, "no transfer read");
974 
975  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
976  if (failed(warpMatrixInfo))
977  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
978 
979  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
980  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
981  if (failed(ldFragmentInfo))
982  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
983 
984  assert(
985  (mmaSyncFragmentInfo->elementsPerRegister ==
986  ldFragmentInfo->elementsPerRegister) &&
987  "Number of elements per register should be same for load and mma.sync");
988 
989  // Create vector.extract_strided_slice op for thread-owned fragments.
990  std::array<int64_t, 2> strides = {1,
991  1}; // stride for extract slice is always 1.
992  std::array<int64_t, 2> sliceShape = {
993  mmaSyncFragmentInfo->numRegistersPerFragment,
994  mmaSyncFragmentInfo->elementsPerRegister};
995  auto it = valueMapping.find(transferReadOp);
996  if (it == valueMapping.end())
997  return rewriter.notifyMatchFailure(op, "no mapping");
998  auto sourceVector = it->second;
999 
1000  // Offsets and sizes at the warp level of ownership.
1001  SmallVector<int64_t> offsets;
1002  populateFromInt64AttrArray(op.getOffsets(), offsets);
1003 
1004  SmallVector<int64_t> sizes;
1005  populateFromInt64AttrArray(op.getSizes(), sizes);
1006  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1007 
1008  // Compute offset in vector registers. Note that the mma.sync vector registers
1009  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1010  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1011  std::array<int64_t, 2> sliceOffset = {0, 0};
1012 
1013  if (offsets[0] && offsets[1])
1014  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1015  if (offsets[0])
1016  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1017  else if (offsets[1])
1018  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1019 
1020  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1021  loc, sourceVector, sliceOffset, sliceShape, strides);
1022 
1023  valueMapping[op] = newOp;
1024  return success();
1025 }
1026 
1027 static LogicalResult
1028 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1029  llvm::DenseMap<Value, Value> &valueMapping) {
1030  OpBuilder::InsertionGuard g(rewriter);
1031  rewriter.setInsertionPoint(op);
1032 
1033  auto itA = valueMapping.find(op.getLhs());
1034  auto itB = valueMapping.find(op.getRhs());
1035  auto itC = valueMapping.find(op.getAcc());
1036  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1037  itC == valueMapping.end())
1038  return rewriter.notifyMatchFailure(op, "no mapping");
1039  Value opA = itA->second, opB = itB->second, opC = itC->second;
1040  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1041  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1042  /*b_transpose=*/UnitAttr());
1043  valueMapping[op.getResult()] = matmul;
1044  return success();
1045 }
1046 
1047 static LogicalResult
1048 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1049  llvm::DenseMap<Value, Value> &valueMapping) {
1050  OpBuilder::InsertionGuard g(rewriter);
1051  rewriter.setInsertionPoint(op);
1052 
1053  auto itA = valueMapping.find(op.getLhs());
1054  auto itB = valueMapping.find(op.getRhs());
1055  auto itC = valueMapping.find(op.getAcc());
1056  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1057  itC == valueMapping.end())
1058  return rewriter.notifyMatchFailure(op, "no mapping");
1059  Value opA = itA->second, opB = itB->second, opC = itC->second;
1060  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1061  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1062  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1063  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1064  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1065  valueMapping[op.getResult()] = matmul;
1066  return success();
1067 }
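
// For illustration (hypothetical warp-level shapes): with lhs : vector<16x16xf16>
// (m = 16, k = 16) and rhs : vector<8x16xf16> (n = 8, k = 16), the contraction is
// rewritten to
//
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//          : (vector<...>, vector<...>, vector<...>) -> vector<...>
//
// where %a, %b, %c are the per-thread register fragments recorded earlier in
// valueMapping, not the original warp-level vectors.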
1068 
1069 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1070 static LogicalResult
1071 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1072  llvm::DenseMap<Value, Value> &valueMapping) {
1073  OpBuilder::InsertionGuard g(rewriter);
1074  rewriter.setInsertionPoint(op);
1075 
1076  assert(constantSupportsMMAMatrixType(op));
1077 
1078  auto splat =
1079  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1080  auto scalarConstant =
1081  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1082  const char *fragType = inferFragType(op);
1083  auto vecType = cast<VectorType>(op.getType());
1084  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1085  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1086  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1087  op.getLoc(), type, scalarConstant);
1088  valueMapping[op.getResult()] = matrix;
1089  return success();
1090 }
1091 
1092 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1093 static LogicalResult
1094 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1095  llvm::DenseMap<Value, Value> &valueMapping) {
1096  OpBuilder::InsertionGuard g(rewriter);
1097  rewriter.setInsertionPoint(op);
1098 
1099  assert(broadcastSupportsMMAMatrixType(op));
1100 
1101  const char *fragType = inferFragType(op);
1102  auto vecType = op.getResultVectorType();
1103  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1104  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1105  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1106  op.getLoc(), type, op.getSource());
1107  valueMapping[op.getResult()] = matrix;
1108  return success();
1109 }
1110 
1111 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1112 // updated and needs to be updated separately for the loop to be correct.
1113 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1114  scf::ForOp loop,
1115  ValueRange newInitArgs) {
1116  OpBuilder::InsertionGuard g(rewriter);
1117  rewriter.setInsertionPoint(loop);
1118 
1119  // Create a new loop before the existing one, with the extra operands.
1120  rewriter.setInsertionPoint(loop);
1121  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1122  llvm::append_range(operands, newInitArgs);
1123  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1124  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1125  operands);
1126  rewriter.eraseBlock(newLoop.getBody());
1127 
1128  newLoop.getRegion().getBlocks().splice(
1129  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1130  for (Value operand : newInitArgs)
1131  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1132 
1133  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1134  loop.getNumResults())))
1135  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1136 
1137  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1138  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1139  LLVM_DEBUG(DBGS() << "erase: " << loop);
1140 
1141  rewriter.eraseOp(loop);
1142  return newLoop;
1143 }
1144 
1145 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1146  llvm::DenseMap<Value, Value> &valueMapping) {
1147  OpBuilder::InsertionGuard g(rewriter);
1148  rewriter.setInsertionPoint(op);
1149 
1150  SmallVector<Value> newOperands;
1151  SmallVector<std::pair<size_t, size_t>> argMapping;
1152  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1153  auto it = valueMapping.find(operand.value());
1154  if (it == valueMapping.end()) {
1155  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1156  continue;
1157  }
1158  argMapping.push_back(std::make_pair(
1159  operand.index(), op.getInitArgs().size() + newOperands.size()));
1160  newOperands.push_back(it->second);
1161  }
1162 
1163  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1164  Block &loopBody = *newForOp.getBody();
1165  for (auto mapping : argMapping) {
1166  valueMapping[newForOp.getResult(mapping.first)] =
1167  newForOp.getResult(mapping.second);
1168  valueMapping[loopBody.getArgument(mapping.first +
1169  newForOp.getNumInductionVars())] =
1170  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1171  }
1172 
1173  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1174  return success();
1175 }
1176 
1177 static LogicalResult
1178 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1179  llvm::DenseMap<Value, Value> &valueMapping) {
1180  OpBuilder::InsertionGuard g(rewriter);
1181  rewriter.setInsertionPoint(op);
1182 
1183  auto loop = cast<scf::ForOp>(op->getParentOp());
1184  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1185  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1186  auto it = valueMapping.find(operand.value());
1187  if (it == valueMapping.end())
1188  continue;
1189  // Replace the yield of the old value with the for op's init argument to make
1190  // it easier to remove the dead code.
1191  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1192  yieldOperands.push_back(it->second);
1193  }
1194  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1195 
1196  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1197  rewriter.eraseOp(op);
1198  return success();
1199 }
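
// For illustration (simplified): an accumulator carried through a loop,
//
//   %r = scf.for %iv = %lb to %ub step %s iter_args(%acc = %init)
//          -> (vector<16x16xf16>) { ... }
//
// is rewritten so the loop additionally carries the mapped MMA value; the old
// iter_arg is made to yield its own init value and becomes dead code:
//
//   %r:2 = scf.for %iv = %lb to %ub step %s
//            iter_args(%acc = %init, %accMma = %initMma)
//            -> (vector<16x16xf16>, !gpu.mma_matrix<16x16xf16, "COp">) { ... }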
1200 
1201 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1202 static LogicalResult
1203 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1204  gpu::MMAElementwiseOp opType,
1205  llvm::DenseMap<Value, Value> &valueMapping) {
1206  OpBuilder::InsertionGuard g(rewriter);
1207  rewriter.setInsertionPoint(op);
1208 
1209  SmallVector<Value> matrixOperands;
1210  for (Value operand : op->getOperands()) {
1211  auto it = valueMapping.find(operand);
1212  if (it == valueMapping.end())
1213  return rewriter.notifyMatchFailure(op, "no mapping");
1214  matrixOperands.push_back(it->second);
1215  }
1216  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1217  if (opType == gpu::MMAElementwiseOp::EXTF) {
1218  // The floating point extension case has a different result type.
1219  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1220  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1221  vectorType.getElementType(),
1222  resultType.getOperand());
1223  }
1224 
1225  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1226  op->getLoc(), resultType, matrixOperands, opType);
1227  valueMapping[op->getResult(0)] = newOp;
1228  return success();
1229 }
1230 
1231 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1232  bool useNvGpu) {
1233  if (!useNvGpu) {
1234  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1235  patterns.getContext());
1236  return;
1237  }
1238  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1239  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1240 }
1241 
1242 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1243  Operation *rootOp) {
1244  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1245  llvm::DenseMap<Value, Value> valueMapping;
1246 
1247  auto globalRes = LogicalResult::success();
1248  for (Operation *op : ops) {
1249  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1250  // Apparently callers do not want to early exit on failure here.
1251  auto res = LogicalResult::success();
1252  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1253  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1254  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1255  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1256  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1257  res = convertContractOp(rewriter, contractOp, valueMapping);
1258  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1259  res = convertConstantOp(rewriter, constantOp, valueMapping);
1260  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1261  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1262  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1263  res = convertForOp(rewriter, forOp, valueMapping);
1264  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1265  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1266  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1267  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1268  }
1269  if (failed(res))
1270  globalRes = failure();
1271  }
1272  return globalRes;
1273 }
1274 
1275 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1276  Operation *rootOp) {
1277  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1278  llvm::DenseMap<Value, Value> valueMapping;
1279  for (Operation *op : ops) {
1280  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1281  .Case([&](vector::TransferReadOp transferReadOp) {
1282  return convertTransferReadToLoads(rewriter, transferReadOp,
1283  valueMapping);
1284  })
1285  .Case([&](vector::TransferWriteOp transferWriteOp) {
1286  return convertTransferWriteToStores(rewriter, transferWriteOp,
1287  valueMapping);
1288  })
1289  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1290  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1291  valueMapping);
1292  })
1293  .Case([&](vector::ContractionOp contractionOp) {
1294  return convertContractOpToMmaSync(rewriter, contractionOp,
1295  valueMapping);
1296  })
1297  .Case([&](scf::ForOp forOp) {
1298  return convertForOp(rewriter, forOp, valueMapping);
1299  })
1300  .Case([&](scf::YieldOp yieldOp) {
1301  return convertYieldOp(rewriter, yieldOp, valueMapping);
1302  })
1303  .Case([&](arith::ConstantOp constOp) {
1304  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1305  })
1306  .Default([&](Operation *op) {
1307  return op->emitError() << "unhandled vector to mma type: " << *op;
1308  })
1309  .failed()) {
1310  return op->emitOpError()
1311  << "failed to convert op during vector-to-nvgpu conversion";
1312  }
1313  }
1314  return success();
1315 }
1316 
1317 namespace {
1318 
1319 struct ConvertVectorToGPUPass
1320  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1321 
1322  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1323  useNvGpu.setValue(useNvGpu_);
1324  }
1325 
1326  void runOnOperation() override {
1327  RewritePatternSet patterns(&getContext());
1328  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1329  if (failed(
1330  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1331  return signalPassFailure();
1332 
1333  IRRewriter rewriter(&getContext());
1334  if (useNvGpu) {
1335  if (failed(
1336  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1337  return signalPassFailure();
1338  return;
1339  }
1340  (void)convertVectorToMMAOps(rewriter, getOperation());
1341  }
1342 };
1343 
1344 } // namespace
1345 
1346 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1347  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1348 }
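
// For illustration, one way to run this conversion programmatically (a minimal
// sketch; whether to nest the pass and the useNvGpu setting depend on the
// surrounding pipeline):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
//   if (failed(pm.run(module)))
//     signalError();   // hypothetical error handling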
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map on the most minor dimensions.
Definition: AffineMap.cpp:155
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs, and as many symbols as the largest symbol in exprs.
Definition: AffineMap.cpp:312
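For illustration, a minimal sketch of inferFromExprList; this is not from the file, and the helper name and the ctx parameter are assumed placeholders.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;

// Hypothetical helper: build matvec-style indexing maps (i, j), (j), (i);
// each inferred map gets two dims and no symbols.
static SmallVector<AffineMap, 4> buildMatvecMaps(MLIRContext *ctx) {
  AffineExpr i, j;
  bindDims(ctx, i, j);
  SmallVector<AffineExpr, 2> lhs{i, j};
  SmallVector<AffineExpr, 1> rhs{j}, acc{i};
  SmallVector<ArrayRef<AffineExpr>, 3> exprs{lhs, rhs, acc};
  return AffineMap::inferFromExprList(exprs, ctx);
}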
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
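For illustration, a minimal sketch; builder is an assumed OpBuilder, and treating a single-element value list as a splat is an assumption about this overload.

// Hypothetical sketch: a 16x16 f16 splat-of-zero attribute, the kind of 2D
// splat constant the conversions above expect.
auto vecType = VectorType::get({16, 16}, builder.getF16Type());
Attribute zeroElt = builder.getFloatAttr(builder.getF16Type(), 0.0);
// Assumption: a single value is expanded to a splat over the whole shape.
DenseElementsAttr splat =
    DenseElementsAttr::get(vecType, ArrayRef<Attribute>{zeroElt});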
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an instance of LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:745
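For illustration, a minimal sketch of a trait query; op is an assumed Operation*, and the Elementwise trait is the one referenced near the end of this listing.

// Hypothetical sketch: single-result ops carrying the Elementwise trait are
// the kind of candidates the elementwise MMA conversion looks for.
if (op->hasTrait<OpTrait::Elementwise>() && op->getNumResults() == 1) {
  // ... treat `op` as an elementwise candidate ...
}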
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
user_iterator user_begin()
Definition: Operation.h:865
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:122
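For illustration, a minimal sketch; builder is an assumed OpBuilder, and the operand string follows the GPU dialect's "AOp"/"BOp"/"COp" convention.

// Hypothetical sketch: a 16x16 f16 fragment type for the LHS matrix operand.
gpu::MMAMatrixType aFrag =
    gpu::MMAMatrixType::get({16, 16}, builder.getF16Type(), "AOp");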
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1144
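For illustration, a minimal sketch composing base + offset into a single affine.apply; builder, loc, base, and offset are assumed placeholders.

// Hypothetical sketch: fold `base + offset` into one composed affine.apply.
AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
AffineMap sum = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
SmallVector<OpFoldResult> operands = {base, offset};
Value composed =
    affine::makeComposedAffineApply(builder, loc, sum, operands).getResult();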
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) to two results representing offsets within a matrix operand.
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing offsets within the matrix operand.
Definition: MMAUtils.cpp:238
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given a type that contains info for a warp-matrix operand and whether or not the load is a transposed load, returns the LdMatrixParams.
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.
Definition: MMAUtils.cpp:276
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:152
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:147
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction with MMT semantics (matrix matmul with the RHS transposed).
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N-1].
Definition: AffineExpr.h:348
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op).
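For illustration, a minimal sketch; rootOp is an assumed Operation*, and the filter keeps the slice within rootOp's block.

// Hypothetical sketch: gather the transitive defs feeding `rootOp`, restricted
// to its block; the result can then be ordered with topologicalSort if needed.
BackwardSliceOptions options;
options.filter = [&](Operation *candidate) {
  return candidate->getBlock() == rootOp->getBlock();
};
SetVector<Operation *> slice;
getBackwardSlice(rootOp, &slice, options);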
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm.mma.sync lowering path.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
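For illustration, a minimal driver sketch combining the entry points above; rootOp is an assumed isolated-from-above op such as a func.func, and the helper name is hypothetical.

// Hypothetical sketch: run the preparation patterns, then the WMMA conversion.
static LogicalResult runVectorToMMA(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  RewritePatternSet patterns(ctx);
  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
    return failure();
  IRRewriter rewriter(ctx);
  return convertVectorToMMAOps(rewriter, rootOp);
}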
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op).
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:30
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
Definition: MMAUtils.h:52