NVGPUTransformOps.cpp
//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  /// device-side async tokens cannot be materialized in nvvm. We just
  /// convert them to a dummy i32 type in order to easily drop them during
  /// conversion.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
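
// Worked example (illustrative note, not part of the upstream file): with
// kWgmmaSizeM == 64, a WarpgroupAccumulatorType fragmented as
// vector<128x64xf32> has sizeM = 128, sizeN = 64 and f32 elements, so
// numMembers = 64 / 2 = 32. The conversion therefore produces an outer LLVM
// struct of 128 / 64 = 2 inner literal structs, each holding 32 f32 members.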

LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===---------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::producesHandle(getResult(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared memory.
/// Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 16> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
// TODO: this currently assumes that there are no groups that could be in flight
// in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // wait should be in the last stage.
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
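
// Worked example (illustrative note, not part of the upstream file): with
// depth == 3, a nvgpu.device_async_wait in the kernel or prologue is tagged
// with numGroups = 3 - 1 = 2, while epilogue iterations 0 and 1 get
// 3 - 1 - 0 = 2 and 3 - 1 - 1 = 1 respectively, draining the in-flight copy
// groups as the loop winds down.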

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth` first,
///     then those at stage 0, with relative order within each group preserved.
///
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
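
// Worked example (illustrative note, not part of the upstream file): with
// depth == 3 and a loop body {%i = affine.apply ..., async copy using %i,
// compute, yield}, the async copy is a stage-0 op, %i sits in its backward
// slice and is therefore also assigned stage 0, the compute op is assigned
// stage 3, and the emitted order lists the stage-3 ops before the stage-0 ops.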

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite error
/// if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};
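
// Worked example (illustrative note, not part of the upstream file): for the
// m16n8k4 tf32 LHS mapping defined below, lane 5 has
// groupID = 5 floordiv 4 = 1 and threadIDInGroup = 5 mod 4 = 1, so it owns
// the A elements at (row, col) = (1, 1) and (9, 1).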

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.
  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  /// row =  groupID         for a0
  ///        groupID + 8     for a1
  /// col =  threadIDInGroup
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }

  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  /// row =  threadIDInGroup
  /// col =  groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  /// row =  groupID                            for c0 and c1
  ///        groupID + 8                        for c2 and c3
  /// col =  (threadIDInGroup * 2) + (i & 0x1)  for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  ///
  /// row =  groupID                                for ai where 0 <= i < 2 || 4 <= i < 6
  ///        groupID + 8                            otherwise
  ///
  /// col =  (threadIDInGroup * 2) + (i & 0x1)      for ai where i < 4
  ///        (threadIDInGroup * 2) + (i & 0x1) + 8  for ai where i >= 4
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  ///
  /// row =  (threadIDInGroup * 2) + (i & 0x1)      for bi where i < 2
  ///        (threadIDInGroup * 2) + (i & 0x1) + 8  for bi where i >= 2
  ///
  /// col =  groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  /// groupID          = %laneid >> 2
  /// threadIDInGroup  = %laneid % 4
  ///
  /// row =  groupID      for ci where i < 2
  ///        groupID + 8  for ci where i >= 2
  ///
  /// col =  (threadIDInGroup * 2) + (i & 0x1)  for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and stores operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and stores operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
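
// Worked example (illustrative note, not part of the upstream file): for a
// vector<2x2xf16> operand, computeStrides yields {2, 1} and the loop visits
// linear indices 0..3, delinearized to {0,0}, {0,1}, {1,0}, {1,1}; applyFn
// produces the scalar for each position and reduceFn folds it back, e.g. via
// vector.insert in buildMmaSyncMemRefLoadOperand below.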

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
  SmallVector<int64_t> vres{res.begin(), res.end()};
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
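
// Worked example (illustrative note, not part of the upstream file): a
// linalg.matmul on memref<16x4xf32>, memref<4x8xf32> and memref<16x8xf32>
// yields (m, n, k) = (16, 8, 4), which matches the m16n8k4 tf32 entry above:
// each lane assembles a vector<2x1xf32> LHS fragment, a vector<1x1xf32> RHS
// fragment and a vector<2x2xf32> accumulator before issuing nvgpu.mma.sync.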

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operation that performs the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  // clang-format off
  rewriter.create<scf::IfOp>(
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
      // This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      rewriter.create<scf::YieldOp>(loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
      rewriter.create<scf::YieldOp>(loc);
    });
  // clang-format on
  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
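
// Worked example (illustrative note, not part of the upstream file): for a
// shared memref<128x64xf16>, the folded size product is 128 * 64 = 8192
// elements with a 16-bit element type, so buildTmaAsyncLoad returns
// 8192 * (16 / 8) = 16384 bytes, which later feeds the
// mbarrier.arrive.expect_tx transaction count.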

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  SmallVector<TypedValue<MemRefType>> shmems;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}
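
// Worked example (illustrative note, not part of the upstream file): for a
// gpu.launch with block sizes (128, 1, 1), numThreads folds to
// 128 * 1 * 1 = 128, which is used as the arrival count when initializing the
// shared-memory mbarrier above.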

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}