VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "mlir/Transforms/Passes.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
38 
39 using namespace mlir;
40 
41 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
42 /// AffineMap representing offsets to apply to indices, the function fills
43 /// `indices` with the original indices plus the offsets. The offsets are
44 /// applied by taking into account the permutation map of the transfer op. If
45 /// the `offsetMap` has dimension placeholders, those should be provided in
46 /// `dimValues`.
47 template <typename TransferOpType>
48 static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
49  AffineMap offsetMap, ArrayRef<Value> dimValues,
50  SmallVector<Value, 4> &indices) {
51  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
52  Location loc = xferOp.getLoc();
53  unsigned offsetsIdx = 0;
54  for (auto expr : xferOp.getPermutationMap().getResults()) {
55  if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
56  Value prevIdx = indices[dim.getPosition()];
57  SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
58  dims.push_back(prevIdx);
59  AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
60  indices[dim.getPosition()] = makeComposedAffineApply(
61  b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
62  continue;
63  }
64  }
65 }
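// Example (illustrative values, not fixed by the code above): for a transfer
// op indexing a memref at (%i, %j) with a minor identity permutation map, an
// `offsetMap` of (d0) -> (d0 floordiv 4, d0 mod 4) and `dimValues` = {%laneId},
// the filled `indices` become (%i + %laneId floordiv 4, %j + %laneId mod 4),
// each folded through makeComposedAffineApply. The real offset maps are
// produced by the nvgpu MMA utilities used later in this file.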
66 
67 // Return true if the contract op can be converted to an MMA matmul.
68 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
69  bool useNvGpu) {
70  if (!contract.getMasks().empty())
71  return false;
72 
73  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
74  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
75  AffineExpr m, n, k;
76  bindDims(contract.getContext(), m, n, k);
77  auto iteratorTypes = contract.getIteratorTypes().getValue();
78  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
79  vector::isParallelIterator(iteratorTypes[1]) &&
80  vector::isReductionIterator(iteratorTypes[2])))
81  return false;
82 
83  // The contract needs to represent a matmul to be able to convert to
84  // MMAMatrix matmul.
85  if (!useNvGpu &&
86  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
87  return false;
88  if (useNvGpu &&
89  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
90  return false;
91 
92  return true;
93 }
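// Example, written schematically with dims (m, n, k): the WMMA path
// (useNvGpu == false) accepts indexing maps {(m, k), (k, n), (m, n)} with
// iterator types [parallel, parallel, reduction], e.g. a
// vector<16x16xf16> x vector<16x16xf16> -> vector<16x16xf16> matmul. The
// mma.sync path instead expects the B operand in (n, k) form, i.e. maps
// {(m, k), (n, k), (m, n)}.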
94 
95 // Return true if the given map represents a transposed matrix load,
96 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
97 static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
98  auto nDim = permutationMap.getNumDims();
99  if (nDim < 2)
100  return false;
101 
102  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
103  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
104  return permutationMap ==
105  AffineMap::get(nDim, 0, {innerDim, outerDim}, b.getContext());
106 }
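// Example: for a 2-D read the transposed-load map is (d0, d1) -> (d1, d0);
// for a rank-3 source it would be (d0, d1, d2) -> (d2, d1), matching the
// comment above.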
107 
108 // Return the stride for dimension 0 of |type| if it is a memref and has a
109 // constant stride.
110 static std::optional<int64_t>
111 getMemrefConstantHorizontalStride(ShapedType type) {
112  auto memrefType = type.dyn_cast<MemRefType>();
113  if (!memrefType)
114  return std::nullopt;
115  // If the memref is 0 or 1D the horizontal stride is 0.
116  if (memrefType.getRank() < 2)
117  return 0;
118  int64_t offset = 0;
119  SmallVector<int64_t, 2> strides;
120  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
121  strides.back() != 1)
122  return std::nullopt;
123  int64_t stride = strides[strides.size() - 2];
124  if (stride == ShapedType::kDynamic)
125  return std::nullopt;
126  return stride;
127 }
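// Example: a row-major (identity layout) memref<16x32xf16> has strides
// [32, 1], so this returns 32; with a dynamic leading stride, e.g.
// memref<16x32xf16, strided<[?, 1], offset: ?>>, it returns std::nullopt.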
128 
129 // Return true if the transfer op can be converted to a MMA matrix load.
130 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
131  bool useNvGpu) {
132  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
133  readOp.getVectorType().getRank() != 2)
134  return false;
135  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
136  return false;
137  AffineMap map = readOp.getPermutationMap();
138  OpBuilder b(readOp.getContext());
139  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
140  AffineExpr zero = b.getAffineConstantExpr(0);
141  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
142  readOp.getContext());
143 
144  if (!useNvGpu) {
145  bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
146  isTransposeMatrixLoadMap(b, map);
147  return result;
148  }
149 
150  return true;
151 }
152 
153 // Return true if the transfer op can be converted to a MMA matrix store.
154 static bool
155 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
156  // TODO: support 0-d corner case.
157  if (writeOp.getTransferRank() == 0)
158  return false;
159 
160  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
161  writeOp.getVectorType().getRank() != 2)
162  return false;
163  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
164  return false;
165  // TODO: Support transpose once it is added to GPU dialect ops.
166  if (!writeOp.getPermutationMap().isMinorIdentity())
167  return false;
168  return true;
169 }
170 
171 /// Return true if the constant is a splat to a 2D vector so that it can be
172 /// converted to a MMA constant matrix op.
173 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
174  auto vecType = constantOp.getType().dyn_cast<VectorType>();
175  if (!vecType || vecType.getRank() != 2)
176  return false;
177  return constantOp.getValue().isa<SplatElementsAttr>();
178 }
179 
180 /// Return true if this is a broadcast from scalar to a 2D vector.
181 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
182  return broadcastOp.getVectorType().getRank() == 2 &&
183  broadcastOp.getSource().getType().isa<FloatType>();
184 }
185 
186 /// Return the MMA elementwise enum associated with `op` if it is supported.
187 /// Return `std::nullopt` otherwise.
188 static std::optional<gpu::MMAElementwiseOp>
189 convertElementwiseOpToMMA(Operation *op) {
190  if (isa<arith::AddFOp>(op))
191  return gpu::MMAElementwiseOp::ADDF;
192  if (isa<arith::MulFOp>(op))
193  return gpu::MMAElementwiseOp::MULF;
194  if (isa<arith::SubFOp>(op))
195  return gpu::MMAElementwiseOp::SUBF;
196  if (isa<arith::MaxFOp>(op))
197  return gpu::MMAElementwiseOp::MAXF;
198  if (isa<arith::MinFOp>(op))
199  return gpu::MMAElementwiseOp::MINF;
200  if (isa<arith::DivFOp>(op))
201  return gpu::MMAElementwiseOp::DIVF;
202  if (isa<arith::AddIOp>(op))
203  return gpu::MMAElementwiseOp::ADDI;
204  if (isa<arith::MulIOp>(op))
205  return gpu::MMAElementwiseOp::MULI;
206  if (isa<arith::SubIOp>(op))
207  return gpu::MMAElementwiseOp::SUBI;
208  if (isa<arith::DivSIOp>(op))
209  return gpu::MMAElementwiseOp::DIVS;
210  if (isa<arith::DivUIOp>(op))
211  return gpu::MMAElementwiseOp::DIVU;
212  if (isa<arith::NegFOp>(op))
213  return gpu::MMAElementwiseOp::NEGATEF;
214  return std::nullopt;
215 }
216 
217 /// Return true if the op is supported as elementwise op on MMAMatrix type.
218 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
219  return convertElementwiseOpToMMA(op).has_value();
220 }
221 
222 /// Returns true if the extract strided slice op is supported with `mma.sync`
223 /// path.
224 static bool
225 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
226 
227  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
228  nvgpu::getWarpMatrixInfo(op);
229  if (failed(warpMatrixInfo))
230  return false;
231 
232  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
233  if (failed(contractOp))
234  return false;
235 
236  // Handle vector.extract_strided_slice on registers containing
237  // matrixB and matrixC operands. vector.extract_strided_slice op
238  // is not supported on registers containing matrixA operands.
239  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
240  return (op->getResult(0).getType().cast<VectorType>() ==
241  (*contractOp).getRhs().getType().cast<VectorType>());
242  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
243  return (op->getResult(0).getType().cast<VectorType>() ==
244  (*contractOp).getAcc().getType().cast<VectorType>());
245 
246  return false;
247 }
248 
249 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
250  if (isa<scf::ForOp, scf::YieldOp>(op))
251  return true;
252  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
253  return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
254  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
255  return transferWriteSupportsMMAMatrixType(transferWrite);
256  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
257  return useNvGpu &&
258  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
259  if (auto contract = dyn_cast<vector::ContractionOp>(op))
260  return contractSupportsMMAMatrixType(contract, useNvGpu);
261  if (auto constant = dyn_cast<arith::ConstantOp>(op))
262  return constantSupportsMMAMatrixType(constant);
263  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
264  return broadcastSupportsMMAMatrixType(broadcast);
265  return elementwiseSupportsMMAMatrixType(op);
266 }
267 
268 /// Return an unsorted slice handling scf.for region differently than
269 /// `getSlice`. In scf.for we only want to include as part of the slice elements
270 /// that are part of the use/def chain.
271 static SetVector<Operation *> getSliceContract(Operation *op,
272  TransitiveFilter backwardFilter,
273  TransitiveFilter forwardFilter) {
274  SetVector<Operation *> slice;
275  slice.insert(op);
276  unsigned currentIndex = 0;
277  SetVector<Operation *> backwardSlice;
278  SetVector<Operation *> forwardSlice;
279  while (currentIndex != slice.size()) {
280  auto *currentOp = (slice)[currentIndex];
281  // Compute and insert the backwardSlice starting from currentOp.
282  backwardSlice.clear();
283  getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
284  slice.insert(backwardSlice.begin(), backwardSlice.end());
285 
286  // Compute and insert the forwardSlice starting from currentOp.
287  forwardSlice.clear();
288  // Special case for ForOp, we don't want to include the whole region but
289  // only the value using the region arguments.
290  // TODO: We should refine this to only care about the region arguments being
291  // converted to matrix type.
292  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
293  for (Value forOpResult : forOp.getResults())
294  getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
295  for (BlockArgument &arg : forOp.getRegionIterArgs())
296  getForwardSlice(arg, &forwardSlice, forwardFilter);
297  } else {
298  getForwardSlice(currentOp, &forwardSlice, forwardFilter);
299  }
300  slice.insert(forwardSlice.begin(), forwardSlice.end());
301  ++currentIndex;
302  }
303  return slice;
304 }
305 
306 // Analyze slice of operations based on convert op to figure out if the whole
307 // slice can be converted to MMA operations.
308 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
309  bool useNvGpu) {
310  auto hasVectorDest = [](Operation *op) {
311  return llvm::any_of(op->getResultTypes(),
312  [](Type t) { return t.isa<VectorType>(); });
313  };
314  auto hasVectorSrc = [](Operation *op) {
315  return llvm::any_of(op->getOperandTypes(),
316  [](Type t) { return t.isa<VectorType>(); });
317  };
318  SetVector<Operation *> opToConvert;
319  op->walk([&](vector::ContractionOp contract) {
320  if (opToConvert.contains(contract.getOperation()))
321  return;
322  SetVector<Operation *> dependentOps =
323  getSliceContract(contract, hasVectorDest, hasVectorSrc);
324  // If any instruction cannot use the MMA matrix type, drop the whole
325  // chain. MMA matrices are stored in an opaque type so they cannot be used
326  // by all operations.
327  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
328  return !supportsMMaMatrixType(op, useNvGpu);
329  }))
330  return;
331  opToConvert.insert(dependentOps.begin(), dependentOps.end());
332  });
333  // Sort the operations so that we can convert them in topological order.
334  return topologicalSort(opToConvert);
335 }
336 
337 namespace {
338 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
339 // to MMA matmul.
340 struct PrepareContractToGPUMMA
341  : public OpRewritePattern<vector::ContractionOp> {
342  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
343 
344  LogicalResult matchAndRewrite(vector::ContractionOp op,
345  PatternRewriter &rewriter) const override {
346  Location loc = op.getLoc();
347  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
348 
349  // Set up the parallel/reduction structure in right form.
350  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
351  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
352  AffineExpr m, n, k;
353  bindDims(rewriter.getContext(), m, n, k);
354  static constexpr std::array<int64_t, 2> perm = {1, 0};
355  auto iteratorTypes = op.getIteratorTypes().getValue();
356  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
357  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
358  vector::isParallelIterator(iteratorTypes[1]) &&
359  vector::isReductionIterator(iteratorTypes[2])))
360  return failure();
361  //
362  // Two outer parallel, one inner reduction (matmat flavor).
363  //
364  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
365  // This is the classical row-major matmul, nothing to do.
366  return failure();
367  }
368  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
369  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
370  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
371  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
372  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
373  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
374  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
375  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
376  std::swap(rhs, lhs);
377  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
378  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
379  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
380  std::swap(rhs, lhs);
381  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
382  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
383  std::swap(lhs, rhs);
384  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
385  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
386  std::swap(lhs, rhs);
387  } else {
388  return failure();
389  }
390  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
391  op, lhs, rhs, res,
392  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
393  op.getIteratorTypes());
394  return success();
395  }
396 };
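// Example: a contraction with maps {(k, m), (k, n), (m, n)} (transposed LHS)
// is matched above by inserting vector.transpose %lhs, [1, 0] and rebuilding
// the contraction with the canonical {(m, k), (k, n), (m, n)} maps, which the
// MMA lowering then handles directly.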
397 
398 // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
399 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
400 // respectively. We can fold the transpose operation when loading the data from
401 // Shared Memory to registers.
402 struct CombineTransferReadOpTranspose final
403  : public OpRewritePattern<vector::TransposeOp> {
404  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
405 
406  LogicalResult matchAndRewrite(vector::TransposeOp op,
407  PatternRewriter &rewriter) const override {
408  auto transferReadOp =
409  op.getVector().getDefiningOp<vector::TransferReadOp>();
410  if (!transferReadOp)
411  return failure();
412 
413  // TODO: support 0-d corner case.
414  if (transferReadOp.getTransferRank() == 0)
415  return failure();
416 
417  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
418  return failure();
419  SmallVector<int64_t, 2> perm;
420  op.getTransp(perm);
421  SmallVector<unsigned, 4> permU;
422  for (int64_t o : perm)
423  permU.push_back(unsigned(o));
424  AffineMap permutationMap =
425  AffineMap::getPermutationMap(permU, op.getContext());
426  AffineMap newMap =
427  permutationMap.compose(transferReadOp.getPermutationMap());
428  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
429  op, op.getType(), transferReadOp.getSource(),
430  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
431  transferReadOp.getPadding(), transferReadOp.getMask(),
432  transferReadOp.getInBoundsAttr());
433  return success();
434  }
435 };
436 
437 } // namespace
438 
439 // MMA types have different layout based on how they are used in matmul ops.
440 // Figure the right layout to use by looking at op uses.
441 // TODO: Change the GPU dialect to abstract the layout at this level and
442 // only care about it during lowering to NVVM.
443 template <typename OpTy>
444 static const char *inferFragType(OpTy op) {
445  for (Operation *users : op->getUsers()) {
446  auto contract = dyn_cast<vector::ContractionOp>(users);
447  if (!contract)
448  continue;
449  if (contract.getLhs() == op.getResult())
450  return "AOp";
451  if (contract.getRhs() == op.getResult())
452  return "BOp";
453  }
454  return "COp";
455 }
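// Example: a vector.transfer_read whose result feeds the LHS of a
// vector.contract is tagged "AOp", one feeding the RHS is tagged "BOp", and
// anything else (typically the accumulator) falls through to "COp".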
456 
457 static void convertTransferReadOp(vector::TransferReadOp op,
458  llvm::DenseMap<Value, Value> &valueMapping) {
459  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
460  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
461  std::optional<int64_t> stride =
462  getMemrefConstantHorizontalStride(op.getShapedType());
463  AffineMap map = op.getPermutationMap();
464  // Handle broadcast by setting the stride to 0.
465  if (map.getResult(0).isa<AffineConstantExpr>()) {
466  assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
467  stride = 0;
468  }
469  assert(stride);
470  const char *fragType = inferFragType(op);
471  gpu::MMAMatrixType type =
472  gpu::MMAMatrixType::get(op.getVectorType().getShape(),
473  op.getVectorType().getElementType(), fragType);
474  OpBuilder b(op);
475  bool isTranspose = isTransposeMatrixLoadMap(b, map);
476  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
477  op.getLoc(), type, op.getSource(), op.getIndices(),
478  b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
479  valueMapping[op.getResult()] = load;
480 }
481 
482 static void convertTransferWriteOp(vector::TransferWriteOp op,
483  llvm::DenseMap<Value, Value> &valueMapping) {
484  assert(transferWriteSupportsMMAMatrixType(op));
485  std::optional<int64_t> stride =
486  getMemrefConstantHorizontalStride(op.getShapedType());
487  assert(stride);
488  OpBuilder b(op);
489  Value matrix = valueMapping.find(op.getVector())->second;
490  b.create<gpu::SubgroupMmaStoreMatrixOp>(
491  op.getLoc(), matrix, op.getSource(), op.getIndices(),
492  b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
493  op.erase();
494 }
495 
496 /// Returns the vector type which represents a matrix fragment.
497 static VectorType
498 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
499  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
500  regInfo.elementsPerRegister};
501  Type elType = regInfo.registerLLVMType;
502  if (auto vecType = elType.dyn_cast<VectorType>())
503  elType = vecType.getElementType();
504  return VectorType::get(shape, elType);
505 }
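// Example (assuming an m16n8k16 f16 matrixA fragment, where each lane owns
// 4 registers of 2 halves each): regInfo would describe 4 registers per
// fragment, 2 elements per register, and vector<2xf16> as the register type,
// so the returned operand type is vector<4x2xf16>.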
506 
507 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
508 static LogicalResult
509 convertConstantOpMmaSync(arith::ConstantOp op,
510  llvm::DenseMap<Value, Value> &valueMapping) {
511  OpBuilder b(op);
512  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
513  nvgpu::getWarpMatrixInfo(op);
514  if (failed(warpMatrixInfo))
515  return failure();
516 
517  FailureOr<nvgpu::FragmentElementInfo> regInfo =
518  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
519  if (failed(regInfo))
520  return failure();
521 
522  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
523  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
524  if (!dense)
525  return failure();
526  Value result = b.create<arith::ConstantOp>(
527  op.getLoc(), vectorType,
528  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
529  valueMapping[op.getResult()] = result;
530  return success();
531 }
532 
533 static LogicalResult
534 creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
535  llvm::DenseMap<Value, Value> &valueMapping) {
536  Location loc = op->getLoc();
537 
538  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
539  nvgpu::getWarpMatrixInfo(op);
540  if (failed(warpMatrixInfo))
541  return failure();
542 
543  FailureOr<nvgpu::FragmentElementInfo> regInfo =
544  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
545  if (failed(regInfo))
546  return failure();
547 
548  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
549  *warpMatrixInfo,
550  /*transpose=*/!op.getPermutationMap().isMinorIdentity());
551  if (failed(params)) {
552  return op->emitError()
553  << "failed to convert vector.transfer_read to ldmatrix; this op "
554  "likely "
555  "should not be converted to a nvgpu.ldmatrix call.";
556  }
557 
558  // Adjust the load offset.
559  auto laneId = builder.create<gpu::LaneIdOp>(loc);
560  FailureOr<AffineMap> offsets =
561  nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
562  if (failed(offsets))
563  return failure();
564 
565  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
566 
567  SmallVector<Value, 4> indices;
568  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
569  indices);
570  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
571  loc, vectorType, op.getSource(), indices,
572  !op.getPermutationMap().isMinorIdentity(), params->numTiles);
573  valueMapping[op] = newOp->getResult(0);
574  return success();
575 }
576 
577 static LogicalResult
578 createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
579  llvm::DenseMap<Value, Value> &valueMapping) {
580  Location loc = op.getLoc();
581  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
582  nvgpu::getWarpMatrixInfo(op);
583  if (failed(warpMatrixInfo))
584  return failure();
585  FailureOr<nvgpu::FragmentElementInfo> regInfo =
586  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
587  if (failed(regInfo)) {
588  op->emitError() << "Failed to deduce register fragment type during "
589  "conversion to distributed non-ldmatrix compatible load";
590  return failure();
591  }
592 
593  Value laneId = builder.create<gpu::LaneIdOp>(loc);
594  SmallVector<Value, 4> elements;
595 
596  // This is the individual element type.
597  Type loadedElType = regInfo->registerLLVMType;
598  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
599 
600  Value fill = builder.create<arith::ConstantOp>(
601  op.getLoc(), vectorType.getElementType(),
602  builder.getZeroAttr(vectorType.getElementType()));
603  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
604 
605  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
606 
607  // If we are not transposing, then we can use vectorized loads. Otherwise, we
608  // must load each element individually.
609  if (!isTransposeLoad) {
610  if (!loadedElType.isa<VectorType>()) {
611  loadedElType = VectorType::get({1}, loadedElType);
612  }
613 
614  for (int i = 0; i < vectorType.getShape()[0]; i++) {
615  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
616  op.getLoc(), builder, *warpMatrixInfo);
617  if (failed(coords))
618  return failure();
619  Value logicalValueId = builder.create<arith::ConstantOp>(
620  loc, builder.getIndexType(),
621  builder.getIndexAttr(i * regInfo->elementsPerRegister));
622  SmallVector<Value, 4> newIndices;
623  getXferIndices<vector::TransferReadOp>(
624  builder, op, *coords, {laneId, logicalValueId}, newIndices);
625 
626  Value el = builder.create<vector::LoadOp>(loc, loadedElType,
627  op.getSource(), newIndices);
628  result = builder.create<vector::InsertOp>(loc, el, result,
629  builder.getI64ArrayAttr(i));
630  }
631  } else {
632  if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
633  loadedElType = vecType.getElementType();
634  }
635  for (int i = 0; i < vectorType.getShape()[0]; i++) {
636  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
637  innerIdx++) {
638 
639  Value logicalValueId = builder.create<arith::ConstantOp>(
640  loc, builder.getIndexType(),
641  builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
642  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
643  op.getLoc(), builder, *warpMatrixInfo);
644  if (failed(coords))
645  return failure();
646 
647  SmallVector<Value, 4> newIndices;
648  getXferIndices<vector::TransferReadOp>(
649  builder, op, *coords, {laneId, logicalValueId}, newIndices);
650  Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
651  op.getSource(), newIndices);
652  result = builder.create<vector::InsertOp>(
653  op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
654  }
655  }
656  }
657 
658  valueMapping[op.getResult()] = result;
659  return success();
660 }
661 
662 /// Converts a `vector.transfer_read` operation directly to either a
663 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
664 /// used when converting to `nvgpu.mma.sync` operations.
665 static LogicalResult
666 convertTransferReadToLoads(vector::TransferReadOp op,
667  llvm::DenseMap<Value, Value> &valueMapping) {
668  OpBuilder b(op);
669 
670  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
671  nvgpu::getWarpMatrixInfo(op);
672  if (failed(warpMatrixInfo))
673  return failure();
674 
675  bool isLdMatrixCompatible =
676  op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
677  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
678 
679  VectorType vecTy = op.getVectorType();
680  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
681 
682  // When we are transposing the B operand, ldmatrix will only work if we have
683  // at least 8 rows to read and the width to read for the transpose is 128
684  // bits.
685  if (!op.getPermutationMap().isMinorIdentity() &&
686  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
687  vecTy.getDimSize(0) * bitWidth < 128))
688  isLdMatrixCompatible = false;
689 
690  if (!isLdMatrixCompatible)
691  return createNonLdMatrixLoads(op, b, valueMapping);
692 
693  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
694 }
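// Example: assuming the source lives in memory space 3 (workgroup/shared
// memory) and the inferred tile width is 128 bits, a transposed read of
// vector<8x8xf16> stays on the ldmatrix path (element width 16, second dim
// 8 >= 8, 8 * 16 = 128 bits), while a transposed vector<4x8xf16> read
// (4 * 16 = 64 bits) falls back to createNonLdMatrixLoads above.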
695 
696 static LogicalResult
697 convertTransferWriteToStores(vector::TransferWriteOp op,
698  llvm::DenseMap<Value, Value> &valueMapping) {
699  OpBuilder b(op);
700  Location loc = op->getLoc();
701  Value matrix = valueMapping.find(op.getVector())->second;
702 
703  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
704  nvgpu::getWarpMatrixInfo(op);
705  if (failed(warpMatrixInfo))
706  return failure();
707  FailureOr<nvgpu::FragmentElementInfo> regInfo =
708  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
709  if (failed(regInfo))
710  return failure();
711 
712  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
713  Value laneId = b.create<gpu::LaneIdOp>(loc);
714 
715  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
716  Value logicalValueId = b.create<arith::ConstantOp>(
717  loc, b.getIndexType(),
718  b.getIndexAttr(i * regInfo->elementsPerRegister));
719  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
720  op.getLoc(), b, *warpMatrixInfo);
721  if (failed(coords))
722  return failure();
723 
724  Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
725  SmallVector<Value, 4> newIndices;
726  getXferIndices<vector::TransferWriteOp>(
727  b, op, *coords, {laneId, logicalValueId}, newIndices);
728  b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
729  }
730  op->erase();
731  return success();
732 }
733 
734 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
735  SmallVectorImpl<int64_t> &results) {
736  for (auto attr : arrayAttr)
737  results.push_back(attr.cast<IntegerAttr>().getInt());
738 }
739 
740 static LogicalResult
741 convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
742  llvm::DenseMap<Value, Value> &valueMapping) {
743 
744  OpBuilder b(op);
745  Location loc = op->getLoc();
746 
747  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
748  nvgpu::getWarpMatrixInfo(op);
749  if (failed(warpMatrixInfo))
750  return failure();
751 
752  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
753  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
754  if (failed(mmaSyncFragmentInfo))
755  return failure();
756 
757  // Find the vector.transfer_read whose result vector is being sliced.
758  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
759  if (!transferReadOp)
760  return failure();
761 
762  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
763  if (failed(warpMatrixInfo))
764  return failure();
765 
766  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
767  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
768  if (failed(ldFragmentInfo))
769  return failure();
770 
771  assert(
772  (mmaSyncFragmentInfo->elementsPerRegister ==
773  ldFragmentInfo->elementsPerRegister) &&
774  "Number of elements per register should be same for load and mma.sync");
775 
776  // Create vector.extract_strided_slice op for thread-owned fragments.
777  std::array<int64_t, 2> strides = {1,
778  1}; // stride for extract slice is always 1.
779  std::array<int64_t, 2> sliceShape = {
780  mmaSyncFragmentInfo->numRegistersPerFragment,
781  mmaSyncFragmentInfo->elementsPerRegister};
782  auto sourceVector = valueMapping.find(transferReadOp)->second;
783 
784  // Offsets and sizes at the warp level of ownership.
785  SmallVector<int64_t> offsets;
786  populateFromInt64AttrArray(op.getOffsets(), offsets);
787 
788  SmallVector<int64_t> sizes;
789  populateFromInt64AttrArray(op.getSizes(), sizes);
790  ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
791 
792  // Compute offset in vector registers. Note that the mma.sync vector registers
793  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
794  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
795  std::array<int64_t, 2> sliceOffset = {0, 0};
796 
797  if (offsets[0] && offsets[1])
798  return op->emitError() << "Slicing fragments in 2D is not supported. ";
799  if (offsets[0])
800  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
801  else if (offsets[1])
802  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
803 
804  Value newOp = b.create<vector::ExtractStridedSliceOp>(
805  loc, sourceVector, sliceOffset, sliceShape, strides);
806 
807  valueMapping[op] = newOp;
808  return success();
809 }
810 
811 static void convertContractOp(vector::ContractionOp op,
812  llvm::DenseMap<Value, Value> &valueMapping) {
813  OpBuilder b(op);
814  Value opA = valueMapping.find(op.getLhs())->second;
815  Value opB = valueMapping.find(op.getRhs())->second;
816  Value opC = valueMapping.find(op.getAcc())->second;
817  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
818  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
819  /*b_transpose=*/UnitAttr());
820  valueMapping[op.getResult()] = matmul;
821 }
822 
823 static LogicalResult
824 convertContractOpToMmaSync(vector::ContractionOp op,
825  llvm::DenseMap<Value, Value> &valueMapping) {
826  OpBuilder b(op);
827  Value opA = valueMapping.find(op.getLhs())->second;
828  Value opB = valueMapping.find(op.getRhs())->second;
829  Value opC = valueMapping.find(op.getAcc())->second;
830  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
831  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
832  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
833  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
834  b.getI64ArrayAttr({m, n, k}));
835  valueMapping[op.getResult()] = matmul;
836  return success();
837 }
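// Example: with lhs : vector<16x16xf16>, rhs : vector<8x16xf16> (B stored in
// n x k form) and acc : vector<16x8xf16>, the shapes give m = 16, n = 8,
// k = 16, so the contraction becomes an nvgpu.mma.sync with shape [16, 8, 16].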
838 
839 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
840 static void convertConstantOp(arith::ConstantOp op,
841  llvm::DenseMap<Value, Value> &valueMapping) {
842  assert(constantSupportsMMAMatrixType(op));
843  OpBuilder b(op);
844  auto splat =
845  op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
846  auto scalarConstant =
847  b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
848  const char *fragType = inferFragType(op);
849  auto vecType = op.getType().cast<VectorType>();
850  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
851  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
852  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
853  scalarConstant);
854  valueMapping[op.getResult()] = matrix;
855 }
856 
857 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
858 static void convertBroadcastOp(vector::BroadcastOp op,
859  llvm::DenseMap<Value, Value> &valueMapping) {
860  assert(broadcastSupportsMMAMatrixType(op));
861  OpBuilder b(op);
862  const char *fragType = inferFragType(op);
863  auto vecType = op.getVectorType();
864  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
865  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
866  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
867  op.getSource());
868  valueMapping[op.getResult()] = matrix;
869 }
870 
871 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
872 // updated and needs to be updated separately for the loop to be correct.
873 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
874  ValueRange newIterOperands) {
875  // Create a new loop before the existing one, with the extra operands.
876  OpBuilder::InsertionGuard guard(b);
877  b.setInsertionPoint(loop);
878  auto operands = llvm::to_vector<4>(loop.getIterOperands());
879  operands.append(newIterOperands.begin(), newIterOperands.end());
880  scf::ForOp newLoop =
881  b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
882  loop.getUpperBound(), loop.getStep(), operands);
883  newLoop.getBody()->erase();
884  newLoop.getLoopBody().getBlocks().splice(
885  newLoop.getLoopBody().getBlocks().begin(),
886  loop.getLoopBody().getBlocks());
887  for (Value operand : newIterOperands)
888  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
889 
890  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
891  loop.getNumResults())))
892  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
893  loop.erase();
894  return newLoop;
895 }
896 
897 static void convertForOp(scf::ForOp op,
898  llvm::DenseMap<Value, Value> &valueMapping) {
899  SmallVector<Value> newOperands;
900  SmallVector<std::pair<size_t, size_t>> argMapping;
901  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
902  auto it = valueMapping.find(operand.value());
903  if (it == valueMapping.end())
904  continue;
905  argMapping.push_back(std::make_pair(
906  operand.index(), op.getNumIterOperands() + newOperands.size()));
907  newOperands.push_back(it->second);
908  }
909  OpBuilder b(op);
910  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
911  Block &loopBody = *newForOp.getBody();
912  for (auto mapping : argMapping) {
913  valueMapping[newForOp.getResult(mapping.first)] =
914  newForOp.getResult(mapping.second);
915  valueMapping[loopBody.getArgument(mapping.first +
916  newForOp.getNumInductionVars())] =
917  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
918  }
919 }
920 
921 static void convertYieldOp(scf::YieldOp op,
922  llvm::DenseMap<Value, Value> &valueMapping) {
923  OpBuilder b(op);
924  auto loop = cast<scf::ForOp>(op->getParentOp());
925  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
926  for (const auto &operand : llvm::enumerate(op.getOperands())) {
927  auto it = valueMapping.find(operand.value());
928  if (it == valueMapping.end())
929  continue;
930  // Replace the yield of old value with the for op argument to make it easier
931  // to remove the dead code.
932  yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
933  yieldOperands.push_back(it->second);
934  }
935  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
936  op.erase();
937 }
938 
939 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
940 static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
941  llvm::DenseMap<Value, Value> &valueMapping) {
942  OpBuilder b(op);
943  SmallVector<Value> matrixOperands;
944  for (Value operand : op->getOperands())
945  matrixOperands.push_back(valueMapping.find(operand)->second);
946  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
947  op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
948  valueMapping[op->getResult(0)] = newOp;
949 }
950 
951 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
952  bool useNvGpu) {
953  if (!useNvGpu) {
954  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
955  patterns.getContext());
956  return;
957  }
958  patterns
959  .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
960  patterns.getContext());
961 }
962 
963 void mlir::convertVectorToMMAOps(Operation *rootOp) {
964  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
965  llvm::DenseMap<Value, Value> valueMapping;
966  for (Operation *op : ops) {
967  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
968  convertTransferReadOp(transferRead, valueMapping);
969  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
970  convertTransferWriteOp(transferWrite, valueMapping);
971  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
972  convertContractOp(contractOp, valueMapping);
973  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
974  convertConstantOp(constantOp, valueMapping);
975  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
976  convertBroadcastOp(broadcastOp, valueMapping);
977  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
978  convertForOp(forOp, valueMapping);
979  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
980  convertYieldOp(yieldOp, valueMapping);
981  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
982  convertElementwiseOp(op, *elementwiseType, valueMapping);
983  }
984  }
985 }
986 
987 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
988  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
989  llvm::DenseMap<Value, Value> valueMapping;
990  for (Operation *op : ops) {
991  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
992  .Case([&](vector::TransferReadOp transferReadOp) {
993  return convertTransferReadToLoads(transferReadOp, valueMapping);
994  })
995  .Case([&](vector::TransferWriteOp transferWriteOp) {
996  return convertTransferWriteToStores(transferWriteOp,
997  valueMapping);
998  })
999  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1000  return convertExtractStridedSlice(extractStridedSliceOp,
1001  valueMapping);
1002  })
1003  .Case([&](vector::ContractionOp contractionOp) {
1004  return convertContractOpToMmaSync(contractionOp, valueMapping);
1005  })
1006  .Case([&](scf::ForOp forOp) {
1007  convertForOp(forOp, valueMapping);
1008  return success();
1009  })
1010  .Case([&](scf::YieldOp yieldOp) {
1011  convertYieldOp(yieldOp, valueMapping);
1012  return success();
1013  })
1014  .Case([&](arith::ConstantOp constOp) {
1015  return convertConstantOpMmaSync(constOp, valueMapping);
1016  })
1017  .Default([&](Operation *op) {
1018  op->emitError() << "unhandled vector to mma type: " << *op;
1019  return failure();
1020  })
1021  .failed()) {
1022  op->emitError() << "Failed to convert op " << *op;
1023  return failure();
1024  }
1025  }
1026  return success();
1027 }
1028 
1029 namespace {
1030 
1031 struct ConvertVectorToGPUPass
1032  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1033 
1034  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1035  useNvGpu.setValue(useNvGpu_);
1036  }
1037 
1038  void runOnOperation() override {
1039  RewritePatternSet patterns(&getContext());
1040  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1041  if (failed(
1042  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1043  return signalPassFailure();
1044 
1045  if (useNvGpu.getValue()) {
1046  if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
1047  return signalPassFailure();
1048  }
1049 
1050  (void)convertVectorToMMAOps(getOperation());
1051  }
1052 };
1053 
1054 } // namespace
1055 
1056 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1057  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1058 }