NVGPUTransformOps.cpp
1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
10 
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
42 #define DBGSNL() (llvm::dbgs() << "\n")
43 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
44 
45 //===----------------------------------------------------------------------===//
46 // Apply...ConversionPatternsOp
47 //===----------------------------------------------------------------------===//
48 
49 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
50  TypeConverter &typeConverter, RewritePatternSet &patterns) {
51  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
52  /// device-side async tokens cannot be materialized in nvvm. We just
53  /// convert them to a dummy i32 type in order to easily drop them during
54  /// conversion.
55  populateGpuMemorySpaceAttributeConversions(
56  llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
57  switch (space) {
58  case gpu::AddressSpace::Global:
59  return static_cast<unsigned>(
60  NVVM::NVVMMemorySpace::kGlobalMemorySpace);
61  case gpu::AddressSpace::Workgroup:
62  return static_cast<unsigned>(
63  NVVM::NVVMMemorySpace::kSharedMemorySpace);
64  case gpu::AddressSpace::Private:
65  return 0;
66  }
67  llvm_unreachable("unknown address space enum value");
68  return 0;
69  });
70  llvmTypeConverter.addConversion(
71  [&](nvgpu::DeviceAsyncTokenType type) -> Type {
72  return llvmTypeConverter.convertType(
73  IntegerType::get(type.getContext(), 32));
74  });
75  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
76  return llvmTypeConverter.convertType(
77  IntegerType::get(type.getContext(), 64));
78  });
79  llvmTypeConverter.addConversion(
80  [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
81  Type elemType = type.getFragmented().getElementType();
82  int64_t sizeM = type.getFragmented().getDimSize(0);
83  int64_t sizeN = type.getFragmented().getDimSize(1);
84 
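  // Per-thread accumulator element count: sizeN/2 for 32-bit element types, sizeN/4 for f16.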
85  unsigned numMembers;
86  if (elemType.isF32() || elemType.isInteger(32))
87  numMembers = sizeN / 2;
88  else if (elemType.isF16())
89  numMembers = sizeN / 4;
90  else
91  llvm_unreachable("unsupported type for warpgroup accumulator");
92 
93  SmallVector<Type> innerStructBody;
94  for (unsigned i = 0; i < numMembers; i++)
95  innerStructBody.push_back(elemType);
96  auto innerStructType = LLVM::LLVMStructType::getLiteral(
97  type.getContext(), innerStructBody);
98 
99  SmallVector<Type> structBody;
100  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
101  structBody.push_back(innerStructType);
102 
103  auto convertedType =
104  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
105  return llvmTypeConverter.convertType(convertedType);
106  });
107  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
108  return llvmTypeConverter.convertType(
109  getMBarrierMemrefType(type.getContext(), type));
110  });
111  llvmTypeConverter.addConversion(
112  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
113  return llvmTypeConverter.convertType(
114  IntegerType::get(type.getContext(), 64));
115  });
116  llvmTypeConverter.addConversion(
117  [&](nvgpu::TensorMapDescriptorType type) -> Type {
118  return LLVM::LLVMPointerType::get(type.getContext());
119  });
120  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
121 }
122 
123 LogicalResult
124 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
125  transform::TypeConverterBuilderOpInterface builder) {
126  if (builder.getTypeConverterType() != "LLVMTypeConverter")
127  return emitOpError("expected LLVMTypeConverter");
128  return success();
129 }
130 
131 //===---------------------------------------------------------------------===//
132 // CreateAsyncGroupsOp
133 //===---------------------------------------------------------------------===//
134 
135 void transform::CreateAsyncGroupsOp::getEffects(
136  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
137  transform::consumesHandle(getTargetMutable(), effects);
138  transform::producesHandle(getOperation()->getOpResults(), effects);
139  transform::modifiesPayload(effects);
140 }
141 
142 DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
143  TransformRewriter &rewriter, Operation *target,
144  ApplyToEachResultList &results, TransformState &state) {
145  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
146  results.push_back(target);
147  return DiagnosedSilenceableFailure::success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // PipelineSharedMemoryCopiesOp
152 //===----------------------------------------------------------------------===//
153 
154 /// Returns true if the given type has the default memory space.
155 static bool hasDefaultMemorySpace(BaseMemRefType type) {
156  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
157 }
158 
159 /// Returns true if the given type has the shared (workgroup) memory space.
160 static bool hasSharedMemorySpace(BaseMemRefType type) {
161  auto space =
162  dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
163  return space &&
164  space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
165 }
166 
167 /// Returns the value produced by a load from the default memory space. Returns
168 /// null if the operation is not such a load.
169 static Value getValueLoadedFromGlobal(Operation *op) {
170  // TODO: consider an interface or leveraging the memory effects interface.
171  auto load = dyn_cast<vector::TransferReadOp>(op);
172  if (!load)
173  return nullptr;
174 
175  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
176  if (!loadType || !hasDefaultMemorySpace(loadType))
177  return nullptr;
178  return load;
179 }
180 
181 /// Returns true if the operation is storing the given value into shared memory.
182 static bool isStoreToShared(Operation *op, Value v) {
183  // TODO: consider an interface or leveraging the memory effects interface.
184  auto store = dyn_cast<vector::TransferWriteOp>(op);
185  if (!store || store.getVector() != v)
186  return false;
187 
188  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
189  return storeType && hasSharedMemorySpace(storeType);
190 }
191 
192 /// Returns true if the operation is a load from the default memory space the
193 /// result of which is only stored into the shared memory space.
194 static bool isLoadFromGlobalStoredToShared(Operation *op) {
195  Value loaded = getValueLoadedFromGlobal(op);
196  if (!loaded || !loaded.hasOneUse())
197  return false;
198 
199  return isStoreToShared(*loaded.getUsers().begin(), loaded);
200 }
201 
202 /// Populate `ops` with the set of operations that belong to the stage 0 of the
203 /// pipelined version of the given loop when pipelining copies to shared memory.
204 /// Specifically, this collects:
205 ///
206 /// 1. all loads from global memory, both sync and async;
207 /// 2. the barriers for async loads.
208 ///
209 /// In particular, barriers are omitted if they do not dominate at least one
210 /// async load for which there is not yet a barrier.
211 static LogicalResult
212 collectStage0PipeliningOps(scf::ForOp forOp,
213  llvm::SmallPtrSet<Operation *, 16> &ops) {
214 
215  llvm::SmallPtrSet<Operation *, 16> barriers;
216  for (Operation &op : *forOp.getBody()) {
217  // Bail on nested ops for now.
218  if (op.getNumRegions() > 0)
219  return failure();
220 
221  if (isa<gpu::BarrierOp>(op)) {
222  barriers.insert(&op);
223  continue;
224  }
225 
226  if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
227  ops.insert(&op);
228  ops.insert(std::make_move_iterator(barriers.begin()),
229  std::make_move_iterator(barriers.end()));
230  assert(barriers.empty() &&
231  "expected to have moved the barriers into another set");
232  continue;
233  }
234 
236  ops.insert(&op);
237  continue;
238  }
239  }
240 
241  return success();
242 }
243 
244 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute
245 /// of async wait operations corresponding to pipelined shared memory copies.
246 // TODO: this currently assumes that there are no groups that could be in flight
247 // in the existing code.
248 static void
249 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
250  scf::PipeliningOption::PipelinerPart part,
251  unsigned iteration, unsigned depth) {
252  // Based on the order of copies within the loop we need to set the number
253  // of copies in flight, unless it is already set.
254  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
255  if (!waitOp || waitOp.getNumGroups())
256  return;
257 
258  int numGroupInFlight = 0;
259  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
260  part == scf::PipeliningOption::PipelinerPart::Prologue) {
261  numGroupInFlight = depth - 1;
262  } else {
263  // By construction there should be no wait op in the prologue as all the
264  // wait should be in the last stage.
265  assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
266  // Based on the schedule we pick we know how many groups are in flight for
267  // each iteration of the epilogue.
268  numGroupInFlight = depth - 1 - iteration;
269  }
270  waitOp.setNumGroups(numGroupInFlight);
271 }
272 
273 /// Hook for the loop pipeliner that populates `ops` with the stage information
274 /// as follows:
275 ///
276 /// - operations in `stage0Ops` (typically loads from global memory and
277 /// related barriers) are at stage 0;
278 /// - operations in the backward slice of any stage0Ops are all at stage 0;
279 /// - other operations are at stage `depth`;
280 /// - the internal order of the pipelined loop has ops at stage `depth` first,
281 /// then those at stage 0, with relative order within each group preserved.
282 ///
283 static void getPipelineStages(
284  scf::ForOp forOp,
285  std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
286  unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
287  SetVector<Operation *> dependencies;
288  BackwardSliceOptions options([&](Operation *visited) {
289  return visited->getBlock() == forOp.getBody();
290  });
291  options.inclusive = true;
292  for (Operation &op : forOp.getBody()->getOperations()) {
293  if (stage0Ops.contains(&op)) {
294  LogicalResult result = getBackwardSlice(&op, &dependencies, options);
295  assert(result.succeeded() && "expected a backward slice");
296  (void)result;
297  }
298  }
299 
300  for (Operation &op : forOp.getBody()->getOperations()) {
301  if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
302  opsWithPipelineStages.emplace_back(&op, depth);
303  }
304  for (Operation &op : forOp.getBody()->getOperations()) {
305  if (dependencies.contains(&op))
306  opsWithPipelineStages.emplace_back(&op, 0);
307  }
308 }
309 
310 /// Hook for the loop pipeliner. Replaces op with a predicated version and
311 /// returns the resulting operation. Returns the original op if the predication
312 /// isn't necessary for the given op. Returns null if predication is needed but
313 /// not supported.
314 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
315  Operation *op, Value predicate) {
316  // Some operations may be fine to execute "speculatively" more times than the
317  // original number of iterations, in particular side-effect free operations
318  // and barriers, even if they cannot be predicated.
319  if (isMemoryEffectFree(op) ||
320  isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
321  nvgpu::DeviceAsyncWaitOp>(op)) {
322  return op;
323  }
324 
325  // Otherwise, only async copies can currently be predicated.
326  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
327  if (!asyncCopyOp)
328  return nullptr;
329 
330  // Create srcElement Value based on `predicate`. The next lines generate
331  // the following code:
332  //
333  // srcElement = (pred) ? prevSrcElements : 0;
334  //
335  Location loc = asyncCopyOp->getLoc();
336  Value dstElements =
337  rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
338  Value originalSrcElement =
339  asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
340  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
341  auto srcElements = rewriter.create<arith::SelectOp>(
342  loc, predicate, originalSrcElement, c0Index);
343  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
344  loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
345  asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
346  asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
347  UnitAttr());
348  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
349  return asyncCopyZeroFillOp;
350 }
351 
352 /// Applies loop pipelining with the given depth to the given loop so that
353 /// copies into the shared memory are pipelined. Doesn't affect other loops.
354 /// Returns a pair containing the error state and the pipelined op, the latter
355 /// being null in case of any failure. The error state contains a definite error
356 /// if the IR has been modified and a silenceable error otherwise.
357 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
358 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
359  bool epiloguePeeling) {
360  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
361  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
362  return std::make_tuple(
363  emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
364  scf::ForOp());
365  }
366  if (stage0Ops.empty()) {
367  return std::make_tuple(
368  emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
369  }
370 
371  scf::PipeliningOption options;
372  unsigned maxDepth = depth;
373  auto setAnnotation = [&](Operation *op,
374  scf::PipeliningOption::PipelinerPart part,
375  unsigned iteration) {
376  return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
377  };
378  options.getScheduleFn =
379  [&](scf::ForOp schedulingFor,
380  std::vector<std::pair<Operation *, unsigned>> &ops) {
381  if (schedulingFor != forOp)
382  return;
383  return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
384  };
385  options.annotateFn = setAnnotation;
386  if (!epiloguePeeling) {
387  options.peelEpilogue = false;
388  options.predicateFn = replaceOpWithPredicatedOp;
389  }
390 
391  OpBuilder::InsertionGuard guard(rewriter);
392  rewriter.setInsertionPoint(forOp);
393  bool modifiedIR;
394  FailureOr<scf::ForOp> maybePipelined =
395  pipelineForLoop(rewriter, forOp, options, &modifiedIR);
396  if (succeeded(maybePipelined)) {
397  return std::make_tuple(DiagnosedSilenceableFailure::success(),
398  *maybePipelined);
399  }
400  return std::make_tuple(
401  modifiedIR
402  ? DiagnosedSilenceableFailure::definiteFailure()
403  : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
404  scf::ForOp());
405 }
406 
407 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
408  TransformRewriter &rewriter, scf::ForOp forOp,
409  ApplyToEachResultList &results, TransformState &state) {
410  auto [diag, pipelined] = pipelineForSharedCopies(
411  rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
412  if (diag.succeeded()) {
413  results.push_back(pipelined);
414  return DiagnosedSilenceableFailure::success();
415  }
416  if (diag.isDefiniteFailure()) {
417  auto diag = emitDefiniteFailure("irreversible pipelining failure");
418  if (!getPeelEpilogue()) {
419  diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
420  diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
421  }
422  return diag;
423  }
424 
425  return std::move(diag);
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // RewriteMatmulAsMmaSyncOp
430 //===----------------------------------------------------------------------===//
431 
432 /// Helper struct to encode a pair of row/column indexings in the form of
433 /// affine expressions.
434 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
435  RowColIndexing(AffineExpr row, AffineExpr col)
436  : std::pair<AffineExpr, AffineExpr>(row, col) {}
437 
438  AffineExpr row() const { return first; };
439  AffineExpr col() const { return second; };
440 
441  void print(llvm::raw_ostream &os) const {
442  os << "- indexing: " << first << ", " << second;
443  }
444 };
445 
446 /// Helper struct to provide a simple mapping from matmul operations to the
447 /// corresponding mma.sync operation. This is constrained to the case where the
448 /// matmul matches the mma.sync operation 1-1.
449 struct MmaSyncBuilder {
450  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
451  : b(b), loc(loc), laneId(laneId) {}
452 
453  using IndexCalculator =
454  std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
455 
456  /// Create the mma.sync operation corresponding to `linalgOp` along with all
457  /// the supporting load/store and vector operations.
458  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
459 
460 private:
461  struct MmaSyncInfo {
462  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
463  std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
464  vectorShapes;
465  SmallVector<int64_t> mmaShape;
466  bool tf32Enabled;
467  };
468 
469  /// Return the specific index calculator for the given `linalgOp` or failure
470  /// if the op is not supported. This is the toplevel switch that should just
471  /// be Tablegen'd in the future.
472  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
473  TypeRange elementalTypes);
474 
475  //===--------------------------------------------------------------------===//
476  // Instruction-specific row, column indexing expression builders.
477  // These should all be declaratively specified via Tablegen in the future.
478  // The Tablegen specification should be as straightforward as possible to
479  // only model the existing size and type combinations.
480  //===--------------------------------------------------------------------===//
481  //
482  // TODO: Tablegen all this.
483  //===--------------------------------------------------------------------===//
484  // m16n8k4 tf32 case.
485  //===--------------------------------------------------------------------===//
486  /// From the NVIDIA doc:
487  /// groupID = %laneid >> 2
488  /// threadIDInGroup = %laneid % 4
489  /// row = groupID for a0
490  /// groupID + 8 for a1
491  /// col = threadIDInGroup
492  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
493  auto dim = getAffineDimExpr(0, ctx);
494  AffineExpr groupID = dim.floorDiv(4);
495  AffineExpr threadIDInGroup = dim % 4;
496  return {RowColIndexing{groupID, threadIDInGroup},
497  RowColIndexing{groupID + 8, threadIDInGroup}};
498  }
499 
500  /// From the NVIDIA doc:
501  /// groupID = %laneid >> 2
502  /// threadIDInGroup = %laneid % 4
503  /// row = threadIDInGroup
504  /// col = groupID
505  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
506  auto dim = getAffineDimExpr(0, ctx);
507  AffineExpr groupID = dim.floorDiv(4);
508  AffineExpr threadIDInGroup = dim % 4;
509  return {RowColIndexing{threadIDInGroup, groupID}};
510  }
511 
512  /// From the NVIDIA doc:
513  /// groupID = %laneid >> 2
514  /// threadIDInGroup = %laneid % 4
515  /// row = groupID for c0 and c1
516  /// groupID + 8 for c2 and c3
517  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
518  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
519  auto dim = getAffineDimExpr(0, ctx);
520  AffineExpr groupID = dim.floorDiv(4);
521  AffineExpr threadIDInGroup = dim % 4;
522  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
523  RowColIndexing{groupID, threadIDInGroup * 2 + 1},
524  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
525  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
526  }
527 
528  //===--------------------------------------------------------------------===//
529  // m16n8k16 f16 case.
530  //===--------------------------------------------------------------------===//
531  /// From the NVIDIA doc:
532  /// groupID = %laneid >> 2
533  /// threadIDInGroup = %laneid % 4
534  ///
535  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
536  /// groupID + 8 Otherwise
537  ///
538  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
539  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
540  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
541  auto dim = getAffineDimExpr(0, ctx);
542  AffineExpr groupID = dim.floorDiv(4);
543  AffineExpr threadIDInGroup = dim % 4;
544  // clang-format off
545  return {
546  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
547  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
548  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
549  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
550  RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
551  RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
552  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
553  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
554  };
555  // clang-format on
556  }
557 
558  /// From the NVIDIA doc:
559  /// groupID = %laneid >> 2
560  /// threadIDInGroup = %laneid % 4
561  ///
562  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
563  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
564  ///
565  /// col = groupID
566  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
567  auto dim = getAffineDimExpr(0, ctx);
568  AffineExpr groupID = dim.floorDiv(4);
569  AffineExpr threadIDInGroup = dim % 4;
570  // clang-format off
571  return {
572  RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
573  RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
574  RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
575  RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
576  };
577  // clang-format on
578  }
579 
580  /// From the NVIDIA doc:
581  /// groupID = %laneid >> 2
582  /// threadIDInGroup = %laneid % 4
583  ///
584  /// row = groupID for ci where i < 2
585  /// groupID + 8 for ci where i >= 2
586  ///
587  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
588  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
589  auto dim = getAffineDimExpr(0, ctx);
590  AffineExpr groupID = dim.floorDiv(4);
591  AffineExpr threadIDInGroup = dim % 4;
592  // clang-format off
593  return {
594  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
595  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
596  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
597  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
598  };
599  // clang-format on
600  }
601 
602  //===--------------------------------------------------------------------===//
603  /// Helper functions to create customizable load and stores operations. The
604  /// specific shapes of each MMA instruction are passed via the
605  /// IndexCalculator callback.
606  //===--------------------------------------------------------------------===//
607  /// Build a list of memref.load operations indexed at `(row, col)` indices
608  /// that make sense for a particular MMA instruction and specified via the
609  /// IndexCalculator callback.
610  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
611  OpFoldResult laneId, Value memref,
612  const IndexCalculator &indexFn);
613 
614  /// Perform a distributed load of a vector operand of `vectorShape` for a
615  /// particular MMA instruction whose `(row, col)` indices are specified via
616  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
617  /// data that makes sense for the particular MMA operation.
618  /// The `vectorShape` matches existing NVGPU dialect op specification but
619  /// could also be flattened in the future if needed for simplification.
620  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
621  OpFoldResult laneId, Value memref,
622  IndexCalculator indexFn,
623  ArrayRef<int64_t> vectorShape);
624 
625  /// Build a list of memref.store operations indexed at `(row, col)` indices
626  /// that make sense for a particular MMA instruction and specified via the
627  /// IndexCalculator callback.
628  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
629  ValueRange toStore,
630  OpFoldResult laneId, Value memref,
631  const IndexCalculator &indexFn);
632 
633  /// Perform a distributed store of a vector operand of `vectorShape` for a
634  /// particular MMA instruction whose `(row, col)` indices are specified via
635  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
636  /// data that makes sense for the particular MMA operation.
637  /// The `vectorShape` matches existing NVGPU dialect op specification but
638  /// could also be flattened in the future if needed for simplification.
639  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
640  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
641  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
642 
643  OpBuilder &b;
644  Location loc;
645  OpFoldResult laneId;
646 };
647 
648 //===--------------------------------------------------------------------===//
649 /// Helper functions to create customizable load and stores operations. The
650 /// specific shapes of each MMA instruction are passed via the
651 /// IndexCalculator callback.
652 //===--------------------------------------------------------------------===//
653 
654 template <typename ApplyFn, typename ReduceFn>
655 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
656  ReduceFn reduceFn) {
657  VectorType vectorType = cast<VectorType>(vector.getType());
658  auto vectorShape = vectorType.getShape();
659  auto strides = computeStrides(vectorShape);
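  // Visit every scalar element of the vector in row-major linear order; delinearize maps each linear index back to its per-dimension position.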
660  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
661  auto indices = delinearize(idx, strides);
662  reduceFn(applyFn(vector, idx, indices), idx, indices);
663  }
664 }
665 
666 SmallVector<Value>
667 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
668  OpFoldResult laneId, Value memref,
669  const IndexCalculator &indexFn) {
670  auto aff = [&](AffineExpr e) {
671  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
672  };
673  SmallVector<Value> res;
674  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
675  for (auto indexing : indexings) {
676  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
677  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
678  auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
679  res.push_back(load);
680  }
681  return res;
682 }
683 
684 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
685  OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
686  IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
687  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
688 
689  Type elementType = getElementTypeOrSelf(memref.getType());
690  auto vt = VectorType::get(vectorShape, elementType);
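  // Seed the result vector with the first loaded scalar, then overwrite each position with its corresponding load via vector.insert below.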
691  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
692  foreachIndividualVectorElement(
693  res,
694  /*applyFn=*/
695  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
696  return loads[linearIdx];
697  },
698  /*reduceFn=*/
699  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
700  res = b.create<vector::InsertOp>(loc, v, res, indices);
701  });
702 
703  return res;
704 }
705 
706 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
707  OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
708  Value memref, const IndexCalculator &indexFn) {
709  auto aff = [&](AffineExpr e) {
710  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
711  };
712  SmallVector<Operation *> res;
713  for (auto [indexing, val] :
714  llvm::zip_equal(indexFn(b.getContext()), toStore)) {
715  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
716  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
717  Operation *store =
718  b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
719  res.push_back(store);
720  }
721  return res;
722 }
723 
724 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
725  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
726  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
727  SmallVector<Value> toStore;
728  toStore.reserve(32);
729  foreachIndividualVectorElement(
730  vectorToStore,
731  /*applyFn=*/
732  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
733  return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
734  },
735  /*reduceFn=*/
736  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
737  toStore.push_back(v);
738  });
739  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
740 }
741 
742 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
743  SmallVector<int64_t>>
744 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
745  ArrayRef<int64_t> res) {
746  SmallVector<int64_t> vlhs(lhs);
747  SmallVector<int64_t> vrhs(rhs);
748  SmallVector<int64_t> vres(res);
749  return std::make_tuple(vlhs, vrhs, vres);
750 }
751 
752 FailureOr<MmaSyncBuilder::MmaSyncInfo>
753 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
754  TypeRange elementalTypes) {
755  // TODO: Tablegen all this.
756  Type f16 = b.getF16Type();
757  Type f32 = b.getF32Type();
758  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
759  elementalTypes == TypeRange{f32, f32, f32}) {
760  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
761  &MmaSyncBuilder::m16n8k4tf32Rhs,
762  &MmaSyncBuilder::m16n8k4tf32Res),
763  makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
764  SmallVector<int64_t>{opShape},
765  /*tf32Enabled=*/true};
766  }
767  // This is the version with f16 accumulation.
768  // TODO: version with f32 accumulation.
769  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
770  elementalTypes == TypeRange{f16, f16, f16}) {
771  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
772  &MmaSyncBuilder::m16n8k16f16Rhs,
773  &MmaSyncBuilder::m16n8k16f16Res),
774  makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
775  SmallVector<int64_t>{opShape},
776  /*tf32Enabled=*/false};
777  }
778  return failure();
779 }
780 
781 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
782  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
783  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
784  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
785  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
786  "expected lhs to be a 2D memref");
787  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
788  "expected rhs to be a 2D memref");
789  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
790  "expected res to be a 2D memref");
791 
792  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
793  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
794  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
795  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
796  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
797  Type resType = getElementTypeOrSelf(resMemRef.getType());
798 
799  FailureOr<MmaSyncInfo> maybeInfo =
800  getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
801  if (failed(maybeInfo))
802  return failure();
803 
804  MmaSyncInfo info = *maybeInfo;
805  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
806  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
807  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
808  lhsIndexFn, lhsShape);
809  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
810  rhsIndexFn, rhsShape);
811  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
812  resIndexFn, resShape);
813  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
814  info.tf32Enabled);
815  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
816  resShape);
817  return res.getDefiningOp();
818 }
819 
820 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
821  transform::TransformRewriter &rewriter, LinalgOp linalgOp,
822  transform::ApplyToEachResultList &results,
823  transform::TransformState &state) {
824  bool fail = true;
825  // TODO: more robust detection of matmulOp, with transposes etc.
826  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
827  // Do not let matmuls with extended semantics (user-defined indexing
828  // maps) slip through this transform.
829  if (linalgOp.hasUserDefinedMaps()) {
830  return emitSilenceableError()
831  << "only matmul ops with non-extended semantics are supported";
832  }
833  Location loc = linalgOp.getLoc();
834  // TODO: more robust computation of laneId, for now assume a single warp.
835  Value laneId = rewriter.create<gpu::ThreadIdOp>(
836  loc, rewriter.getIndexType(), gpu::Dimension::x);
837  if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
838  fail = false;
839  }
840 
841  if (fail) {
842  DiagnosedSilenceableFailure diag = emitSilenceableError()
843  << "unsupported target op: " << linalgOp;
844  diag.attachNote(linalgOp->getLoc()) << "target op";
845  return diag;
846  }
847 
848  rewriter.eraseOp(linalgOp);
849  return DiagnosedSilenceableFailure::success();
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // Hopper builders.
854 //===----------------------------------------------------------------------===//
855 
856 /// Helper to create the base Hopper-specific operations that are reused in
857 /// various other places.
858 struct HopperBuilder {
859  HopperBuilder(RewriterBase &rewriter, Location loc)
860  : rewriter(rewriter), loc(loc) {}
861 
862  TypedValue<nvgpu::MBarrierGroupType>
863  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
864 
865  /// Create tma descriptor op to initiate transfer from global to shared
866  /// memory. This must be done before the launch op, on the host.
867  TypedValue<nvgpu::TensorMapDescriptorType>
868  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
869  gpu::LaunchOp launchOp);
870 
871  /// Build a tma load from global memory to shared memory using `barrier` to
872  /// synchronize. Return the number of bytes that will be transferred.
873  OpFoldResult
874  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
875  TypedValue<MemRefType> sharedMemref,
876  TypedValue<nvgpu::MBarrierGroupType> barrier,
877  SmallVectorImpl<Operation *> &loadOps);
878  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
879  ArrayRef<OpFoldResult> sizes);
880 
881  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
882  /// Return the operation that performs the transfer on thread0.
883  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
884  SmallVector<Operation *> buildPredicateLoadsOnThread0(
885  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
886  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
887  TypedValue<nvgpu::MBarrierGroupType> barrier);
888 
889  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
890 
891  RewriterBase &rewriter;
892  Location loc;
893 };
894 
895 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
896  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
897  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
898  TypedValue<nvgpu::MBarrierGroupType> barrier) {
899  SmallVector<Operation *> loadOps;
900  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
901  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
902  Value cond =
903  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
904  // clang-format off
905  rewriter.create<scf::IfOp>(
906  /*location=*/loc,
907  /*conditional=*/cond,
908  /*thenBuilder=*/
909  [&](OpBuilder &lb, Location loc) {
910  SmallVector<OpFoldResult> sizes;
911  sizes.reserve(globalDescriptors.size());
912  for (auto [desc, shmem] : llvm::zip_equal(
913  globalDescriptors, sharedMemBuffers)) {
914  OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
915  sizes.push_back(sz);
916  }
917  // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
918  // This may or may not have perf implications.
919  buildBarrierArriveTx(barrier, sizes);
920  rewriter.create<scf::YieldOp>(loc);
921  },
922  /*elseBuilder=*/
923  [&](OpBuilder &lb, Location loc) {
924  // TODO: is this for no-thread divergence?
925  // Should we just yield the size and hoist?
926  buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
927  rewriter.create<scf::YieldOp>(loc);
928  });
929  // clang-format on
930  return loadOps;
931 }
932 
933 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
934  return gpu::AddressSpaceAttr::get(
935  b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
936  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
937 }
938 
939 TypedValue<nvgpu::MBarrierGroupType>
940 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
941  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
942  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
943  loc,
944  nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
945  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
946  rewriter.create<nvgpu::MBarrierInitOp>(
947  loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
948  zero, Value());
949  rewriter.create<gpu::BarrierOp>(loc);
950  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
951 }
952 
953 TypedValue<nvgpu::TensorMapDescriptorType>
954 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
955  gpu::LaunchOp launchOp) {
956  OpBuilder::InsertionGuard guard(rewriter);
957  rewriter.setInsertionPoint(launchOp);
958  Value unrankedMemRef = rewriter.create<memref::CastOp>(
959  loc,
960  UnrankedMemRefType::get(memref.getType().getElementType(),
961  memref.getType().getMemorySpace()),
962  memref);
963  SmallVector<OpFoldResult> mixedSizes =
964  memref::getMixedSizes(rewriter, loc, memref);
965  SmallVector<Value> sizes =
966  getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
967 
968  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
969  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
970  loc,
971  nvgpu::TensorMapDescriptorType::get(
972  rewriter.getContext(),
973  MemRefType::Builder(memref.getType())
974  .setMemorySpace(sharedMemorySpace),
975  TensorMapSwizzleKind::SWIZZLE_NONE,
976  TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
977  TensorMapInterleaveKind::INTERLEAVE_NONE),
978  unrankedMemRef, sizes);
979  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
980 }
981 
982 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
983  TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
984  TypedValue<MemRefType> sharedMemref,
985  TypedValue<nvgpu::MBarrierGroupType> barrier,
986  SmallVectorImpl<Operation *> &loadOps) {
987  MLIRContext *ctx = rewriter.getContext();
988  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
989  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
990  loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
991  Value(), Value());
992  loadOps.push_back(loadOp);
993  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
994  SmallVector<AffineExpr> symbols(mixedSizes.size());
995  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
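  // Number of bytes transferred = product of the shared memref's sizes times the element byte width.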
996  AffineExpr prodExprInBytes =
997  computeProduct(ctx, symbols) *
998  (sharedMemref.getType().getElementTypeBitWidth() / 8);
999  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
1000  prodExprInBytes, mixedSizes);
1001  return res;
1002 }
1003 
1004 void HopperBuilder::buildBarrierArriveTx(
1005  TypedValue<nvgpu::MBarrierGroupType> barrier,
1006  ArrayRef<OpFoldResult> mixedSizes) {
1007  assert(!mixedSizes.empty() && "expected non-empty sizes");
1008  MLIRContext *ctx = rewriter.getContext();
1009  SmallVector<AffineExpr> symbols(mixedSizes.size());
1010  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1011  AffineExpr sumExpr = computeSum(ctx, symbols);
1012  OpFoldResult size =
1013  affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1014  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1015  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1016  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1017  Value());
1018 }
1019 
1020 void HopperBuilder::buildTryWaitParity(
1021  TypedValue<nvgpu::MBarrierGroupType> barrier) {
1022  Type i1 = rewriter.getI1Type();
1023  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
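  // Parity 0 waits on the initial completion phase of the mbarrier.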
1024  // 10M is an arbitrary, not too small or too big number to specify the number
1025  // of ticks before retry.
1026  // TODO: hoist this in a default dialect constant.
1027  Value ticksBeforeRetry =
1028  rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
1029  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1030  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1031  ticksBeforeRetry, zero);
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // RewriteCopyAsTmaOp
1036 //===----------------------------------------------------------------------===//
1037 
1038 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1039 struct CopyBuilder : public HopperBuilder {
1040  CopyBuilder(RewriterBase &rewriter, Location loc)
1041  : HopperBuilder(rewriter, loc) {}
1042 
1043  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1044 };
1045 
1046 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1047  MLIRContext *ctx = rewriter.getContext();
1048  if (copyOps.empty())
1049  return SmallVector<Operation *>();
1050 
1051  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1052  assert(launchOp && "expected launch op");
1053 
1054  // 1. Init a barrier object in shared memory.
1055  OpBuilder::InsertionGuard g(rewriter);
1056  rewriter.setInsertionPoint(copyOps.front());
1057  AffineExpr bx, by, bz;
1058  bindSymbols(ctx, bx, by, bz);
1059  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
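  // Total threads in the block (blockDim.x * blockDim.y * blockDim.z), used as the mbarrier arrival count.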
1060  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1061  rewriter, loc, prod,
1062  ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1063  launchOp.getBlockSizeZ()});
1064 
1065  TypedValue<nvgpu::MBarrierGroupType> barrier =
1066  buildAndInitBarrierInSharedMemory(numThreads);
1067 
1068  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
1069  SmallVector<TypedValue<MemRefType>> shmems;
1070  for (Operation *op : copyOps) {
1071  auto copyOp = cast<linalg::CopyOp>(op);
1072  auto inMemRef =
1073  cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1074  assert(inMemRef.getType().getRank() == 2 &&
1075  "expected in to be a 2D memref");
1076 
1077  // 2. Build global memory descriptor.
1078  TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
1079  buildGlobalMemRefDescriptor(inMemRef, launchOp);
1080  globalDescs.push_back(globalDesc);
1081 
1082  // 3. Shared memory and descriptor for the tmp array.
1083  auto shmem =
1084  cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1085  shmems.push_back(shmem);
1086  }
1087 
1088  // 4. Load in from global memory to shared memory using tma.
1089  OpBuilder::InsertionGuard g2(rewriter);
1090  rewriter.setInsertionPoint(copyOps.front());
1091  SmallVector<Operation *> results =
1092  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1093 
1094  // 5. Spin-loop until data is ready.
1095  buildTryWaitParity(barrier);
1096 
1097  // 6. Erase the ops that have now been rewritten.
1098  for (Operation *op : copyOps)
1099  rewriter.eraseOp(op);
1100 
1101  return results;
1102 }
1103 
1104 DiagnosedSilenceableFailure
1105 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1106  transform::TransformResults &results,
1107  transform::TransformState &state) {
1108  auto payloadOps = state.getPayloadOps(getTarget());
1109  gpu::LaunchOp commonLaunchOp;
1110  Operation *firstOp, *failingOp;
1111  if (llvm::any_of(payloadOps, [&](Operation *op) {
1112  if (!commonLaunchOp) {
1113  commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1114  firstOp = op;
1115  }
1116  auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1117  commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1118  !isa<linalg::CopyOp>(op);
1119  if (fail)
1120  failingOp = op;
1121  return fail;
1122  })) {
1123  DiagnosedSilenceableFailure diag =
1124  emitSilenceableError()
1125  << "target ops must be linalg::CopyOp nested under a common "
1126  "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1127  "be created on the host.\nBut got: "
1128  << *firstOp << "\nand " << *failingOp;
1129  return diag;
1130  }
1131 
1132  // TODO: more robust detection of copy, with transposes etc.
1133  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1134 
1135  return DiagnosedSilenceableFailure::success();
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // Transform op registration
1140 //===----------------------------------------------------------------------===//
1141 
1142 namespace {
1143 class NVGPUTransformDialectExtension
1144  : public transform::TransformDialectExtension<
1145  NVGPUTransformDialectExtension> {
1146 public:
1147  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1148 
1149  NVGPUTransformDialectExtension() {
1150  declareGeneratedDialect<arith::ArithDialect>();
1151  declareGeneratedDialect<affine::AffineDialect>();
1152  declareGeneratedDialect<nvgpu::NVGPUDialect>();
1153  declareGeneratedDialect<NVVM::NVVMDialect>();
1154  declareGeneratedDialect<vector::VectorDialect>();
1155  registerTransformOps<
1156 #define GET_OP_LIST
1157 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1158  >();
1159  }
1160 };
1161 } // namespace
1162 
1163 #define GET_OP_CLASSES
1164 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1165 
1166 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1167  registry.addExtensions<NVGPUTransformDialectExtension>();
1168 }