MLIR  20.0.0git
NVGPUTransformOps.cpp
Go to the documentation of this file.
1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
42 #define DBGSNL() (llvm::dbgs() << "\n")
43 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
44 
45 //===----------------------------------------------------------------------===//
46 // Apply...ConversionPatternsOp
47 //===----------------------------------------------------------------------===//
48 
/// Configures the LLVM type converter with the NVGPU-specific type conversions
/// used when lowering NVGPU to NVVM.
///
/// `typeConverter` is static_cast to LLVMTypeConverter. That cast is only valid
/// because verifyTypeConverter() rejects builders whose type converter type is
/// not "LLVMTypeConverter".
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // Map gpu dialect address spaces to unsigned numeric address spaces.
  // Private maps to 0; Global and Workgroup map to the values cast below
  // (presumably the NVVM memory-space enum values -- TODO confirm).
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Device-side async tokens cannot be materialized in NVVM. They are
  // converted to a dummy i32 type so that they can easily be dropped during
  // conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  // mbarrier tokens are represented as i64.
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  // A warpgroup accumulator becomes a literal struct of inner literal structs.
  // There is one inner struct per kWgmmaSizeM rows of the fragmented M
  // dimension. Each inner struct holds sizeN / 2 elements for f32/i32
  // accumulators or sizeN / 4 elements for f16 accumulators.
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        // One inner struct per kWgmmaSizeM-row slice (rounded up).
        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  // An mbarrier group is lowered through its memref representation.
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  // Warpgroup matrix descriptors are represented as i64.
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  // Tensor map (TMA) descriptors become an LLVM pointer type (no explicit
  // address space is passed).
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
}
122 
123 LogicalResult
124 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
125  transform::TypeConverterBuilderOpInterface builder) {
126  if (builder.getTypeConverterType() != "LLVMTypeConverter")
127  return emitOpError("expected LLVMTypeConverter");
128  return success();
129 }
130 
131 //===---------------------------------------------------------------------===//
132 // CreateAsyncGroupsOp
133 //===---------------------------------------------------------------------===//
134 
/// Declares this op's effects on transform handles. The `target` handle is
/// consumed, and every result of the op is produced as a new handle.
void transform::CreateAsyncGroupsOp::getEffects(
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
}
141 
/// Runs nvgpu::createAsyncGroups on `target`, forwarding this op's bypassL1
/// flag. Pushes `target` itself into `results` as this op's result handle.
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
}
149 
150 //===----------------------------------------------------------------------===//
151 // PipelineSharedMemoryCopiesOp
152 //===----------------------------------------------------------------------===//
153 
/// Returns true if the given type has the default memory space. That means it
/// either has no memory-space attribute or its integer memory space is 0.
///
/// NOTE(review): in upstream MLIR, getMemorySpaceAsInt() asserts that a present
/// memory space is an IntegerAttr. A memref whose memory space is, for example,
/// a gpu::AddressSpaceAttr may therefore trip that assertion here. Verify this,
/// and consider testing `dyn_cast_if_present<IntegerAttr>` instead.
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}
158 
/// Returns true if the given type has the shared (workgroup) memory space.
///
/// Only a gpu::AddressSpaceAttr equal to the GPU dialect's workgroup address
/// space qualifies. A missing memory space, or any other attribute kind
/// (including a plain integer memory space), yields false. `type` itself must
/// be non-null, because getMemorySpace() is called on it unconditionally.
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
166 
/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
///
/// "Global" here means the default memory space as defined by
/// hasDefaultMemorySpace. Only vector.transfer_read ops whose source is a
/// memref are recognized. Reads from tensors or from memrefs in other memory
/// spaces return null.
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}
180 
181 /// Returns true if the operation is storing the given value into shared memory.
182 static bool isStoreToShared(Operation *op, Value v) {
183  // TOD: consider an interface or leveraging the memory effects interface.
184  auto store = dyn_cast<vector::TransferWriteOp>(op);
185  if (!store || store.getVector() != v)
186  return false;
187 
188  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
189  return storeType || hasSharedMemorySpace(storeType);
190 }
191 
/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
///
/// "Only" is enforced by requiring the loaded value to have exactly one use,
/// and that single user must satisfy isStoreToShared for the loaded value.
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}
201 
/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared memory.
/// Specifically, this collects:
///
/// 1. all loads from global memory, both sync and async;
/// 2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
///
/// Only ops directly in the loop body are inspected. Returns failure as soon
/// as any op in the body has regions.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,

  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    // Remember barriers. They only join stage 0 once a later async copy (or
    // async group creation) is seen.
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      // FIXME(review): inserting through move iterators only moves from the
      // pointer values and never removes elements from `barriers`. The
      // container therefore stays non-empty, and the assert below fails in
      // assert-enabled builds whenever a gpu.barrier precedes an async copy
      // in the loop. Calling `barriers.clear()` after this insert would match
      // the stated intent.
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    // Per the doc comment above, synchronous loads from global memory are
    // stage-0 ops as well.
      ops.insert(&op);
      continue;
    }
  }

  return success();
}
243 
/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
///
/// Only nvgpu.device_async_wait ops that do not already carry a group count
/// are updated. One branch uses `depth - 1`. The other branch (per the comments
/// below, the epilogue) uses `depth - 1 - iteration`, so fewer groups remain in
/// flight as the epilogue drains.
// TODO: this currently assumes that there are no groups that could be in flight
// in the existing code.
static void
    unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // wait should be in the last stage.
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    // NOTE(review): this is computed in unsigned arithmetic, so it assumes
    // iteration < depth.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
272 
273 /// Hook for the loop pipeliner that populates `ops` with the stage information
274 /// as follows:
275 ///
276 /// - operations in `stage0Ops` (typically loads from global memory and
277 /// related barriers) are at stage 0;
278 /// - operations in the backward slice of any stage0Ops are all at stage 0;
279 /// - other operations are at stage `depth`;
280 /// - the internal order of the pipelined loop has ops at stage `depth` first,
281 /// then those at stage 0, with relative order within each group preserved.
282 ///
283 static void getPipelineStages(
284  scf::ForOp forOp,
285  std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
286  unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
287  SetVector<Operation *> dependencies;
288  BackwardSliceOptions options([&](Operation *visited) {
289  return visited->getBlock() == forOp.getBody();
290  });
291  options.inclusive = true;
292  for (Operation &op : forOp.getBody()->getOperations()) {
293  if (stage0Ops.contains(&op))
294  getBackwardSlice(&op, &dependencies, options);
295  }
296 
297  for (Operation &op : forOp.getBody()->getOperations()) {
298  if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
299  opsWithPipelineStages.emplace_back(&op, depth);
300  }
301  for (Operation &op : forOp.getBody()->getOperations()) {
302  if (dependencies.contains(&op))
303  opsWithPipelineStages.emplace_back(&op, 0);
304  }
305 }
306 
/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
///
/// Only nvgpu.device_async_copy is predicated. It is recreated with
/// srcElements = predicate ? <original src elements, or dst elements> : 0, so
/// a false predicate copies zero source elements.
    Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  // srcElement = (pred) ? prevSrcElements : 0;
  //
  // Note: the dstElements constant is created even when the copy already has
  // srcElements, in which case it is left unused.
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  // NOTE(review): the bypassL1 attribute of the original copy is not forwarded
  // (UnitAttr() is passed). Confirm that dropping it is intended.
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
348 
/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite error
/// if the IR has been modified and a silenceable error otherwise.
///
/// When `epiloguePeeling` is false, the epilogue is not peeled and ops are
/// predicated via replaceOpWithPredicatedOp instead.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  // Failing to collect stage-0 ops, or collecting none, is silenceable: no IR
  // has been touched yet.
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  // The pipeliner hooks work with unsigned depths. `depth` is narrowed here.
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  // Only schedule `forOp` itself; other loops are left untouched.
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  // NOTE(review): `modifiedIR` is not initialized here, so this relies on
  // pipelineForLoop always writing it -- confirm.
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}
403 
/// Pipelines shared-memory copies in `forOp` using this op's depth and
/// peel_epilogue settings. On success the pipelined loop is pushed into
/// `results`. A definite failure (the IR was modified) is re-reported as
/// "irreversible pipelining failure". Without epilogue peeling, notes are
/// attached that hint at predication and the peel-epilogue attribute. Any
/// other (silenceable) failure is returned unchanged.
DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
  }
  if (diag.isDefiniteFailure()) {
    // Note: this `diag` shadows the structured binding of the same name.
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}
424 
425 //===----------------------------------------------------------------------===//
426 // RewriteMatmulAsMmaSyncOp
427 //===----------------------------------------------------------------------===//
428 
/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
///
/// It inherits privately from std::pair, so the raw `first`/`second` members
/// are hidden from users. Use row() and col() instead.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  /// Returns the row index expression.
  AffineExpr row() const { return first; };
  /// Returns the column index expression.
  AffineExpr col() const { return second; };

  /// Prints the indexing as "- indexing: <row>, <col>".
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};
442 
/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
///
/// buildMmaSync emits IR through the stored builder `b` at `loc`. `laneId` is
/// the value bound to d0 when the per-instruction indexing expressions are
/// materialized.
      : b(b), loc(loc), laneId(laneId) {}

  /// Callback producing, for one operand of a specific MMA instruction, the
  /// (row, col) indexings of the fragment elements held by a lane. The
  /// expressions use d0 for the lane id.
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  /// Everything needed to emit one specific mma.sync variant.
  struct MmaSyncInfo {
    /// Index calculators for the (lhs, rhs, res) operands.
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    /// Per-lane vector fragment shapes for the (lhs, rhs, res) operands.
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    /// The (m, n, k) shape passed to nvgpu.mma_sync.
    SmallVector<int64_t> mmaShape;
    /// Whether nvgpu.mma_sync is emitted in tf32 mode.
    bool tf32Enabled;
  };

  /// Return the index calculators, fragment shapes and mma.sync settings for
  /// the given (m, n, k) `opShape` and (lhs, rhs, res) `elementalTypes`, or
  /// failure if that combination is not supported. This is the toplevel switch
  /// that should just be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);
471 
472  //===--------------------------------------------------------------------===//
473  // Instruction-specific row, column indexing expression builders.
474  // These should all be declaratively specified via Tablegen in the future.
475  // The Tablegen specification should be as straightforward as possible to
476  // only model the existing size and type combinations.
477  //===--------------------------------------------------------------------===//
478  //
479  // TODO: Tablegen all this.
480  //===--------------------------------------------------------------------===//
481  // m16n8k4 tf32 case.
482  //===--------------------------------------------------------------------===//
483  /// From the NVIDIA doc:
484  /// groupID = %laneid >> 2
485  /// threadIDInGroup = %laneid % 4
486  /// row = groupID for a0
487  /// groupID + 8 for a1
488  /// col = threadIDInGroup
489  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
490  auto dim = getAffineDimExpr(0, ctx);
491  AffineExpr groupID = dim.floorDiv(4);
492  AffineExpr threadIDInGroup = dim % 4;
493  return {RowColIndexing{groupID, threadIDInGroup},
494  RowColIndexing{groupID + 8, threadIDInGroup}};
495  }
496 
497  /// From the NVIDIA doc:
498  /// groupID = %laneid >> 2
499  /// threadIDInGroup = %laneid % 4
500  /// row = threadIDInGroup
501  /// col = groupID
502  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
503  auto dim = getAffineDimExpr(0, ctx);
504  AffineExpr groupID = dim.floorDiv(4);
505  AffineExpr threadIDInGroup = dim % 4;
506  return {RowColIndexing{threadIDInGroup, groupID}};
507  }
508 
509  /// From the NVIDIA doc:
510  /// groupID = %laneid >> 2
511  /// threadIDInGroup = %laneid % 4
512  /// row = groupID for c0 and c1
513  /// groupID + 8 for c2 and c3
514  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
515  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
516  auto dim = getAffineDimExpr(0, ctx);
517  AffineExpr groupID = dim.floorDiv(4);
518  AffineExpr threadIDInGroup = dim % 4;
519  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
520  RowColIndexing{groupID, threadIDInGroup * 2 + 1},
521  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
522  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
523  }
524 
525  //===--------------------------------------------------------------------===//
526  // m16n8k16 f16 case.
527  //===--------------------------------------------------------------------===//
528  /// From the NVIDIA doc:
529  /// groupID = %laneid >> 2
530  /// threadIDInGroup = %laneid % 4
531  ///
532  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
533  /// groupID + 8 Otherwise
534  ///
535  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
536  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
537  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
538  auto dim = getAffineDimExpr(0, ctx);
539  AffineExpr groupID = dim.floorDiv(4);
540  AffineExpr threadIDInGroup = dim % 4;
541  // clang-format off
542  return {
543  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
544  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
545  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
546  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
547  RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
548  RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
549  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
550  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
551  };
552  // clang-format on
553  }
554 
555  /// From the NVIDIA doc:
556  /// groupID = %laneid >> 2
557  /// threadIDInGroup = %laneid % 4
558  ///
559  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
560  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
561  ///
562  /// col = groupID
563  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
564  auto dim = getAffineDimExpr(0, ctx);
565  AffineExpr groupID = dim.floorDiv(4);
566  AffineExpr threadIDInGroup = dim % 4;
567  // clang-format off
568  return {
569  RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
570  RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
571  RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
572  RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
573  };
574  // clang-format on
575  }
576 
577  /// From the NVIDIA doc:
578  /// groupID = %laneid >> 2
579  /// threadIDInGroup = %laneid % 4
580  ///
581  /// row = groupID for ci where i < 2
582  /// groupID + 8 for ci where i >= 2
583  ///
584  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
585  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
586  auto dim = getAffineDimExpr(0, ctx);
587  AffineExpr groupID = dim.floorDiv(4);
588  AffineExpr threadIDInGroup = dim % 4;
589  // clang-format off
590  return {
591  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
592  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
593  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
594  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
595  };
596  // clang-format on
597  }
598 
  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and stores operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback. The expressions are evaluated with d0 bound to
  /// `laneId`. One load is emitted per indexing, in order.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  /// `indexFn` must produce exactly one indexing per element of `vectorShape`.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback. `toStore` and the indexings are zipped and must
  /// have equal lengths.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  /// NOTE(review): the definition appears not to read `vectorShape`; the
  /// element order comes from the type of `vectorToStore`.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  /// Builder used by buildMmaSync to create all IR.
  OpBuilder &b;
  /// Location given to the ops created by buildMmaSync.
  Location loc;
  /// Lane id bound to d0 in the indexing expressions.
  OpFoldResult laneId;
};
644 
645 //===--------------------------------------------------------------------===//
646 /// Helper functions to create customizable load and stores operations. The
647 /// specific shapes of each MMA instruction are passed via the
648 /// IndexCalculator callback.
649 //===--------------------------------------------------------------------===//
650 
651 template <typename ApplyFn, typename ReduceFn>
652 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
653  ReduceFn reduceFn) {
654  VectorType vectorType = cast<VectorType>(vector.getType());
655  auto vectorShape = vectorType.getShape();
656  auto strides = computeStrides(vectorShape);
657  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
658  auto indices = delinearize(idx, strides);
659  reduceFn(applyFn(vector, idx, indices), idx, indices);
660  }
661 }
662 
// Emits one memref.load from `memref` per indexing produced by `indexFn`, in
// order. Each (row, col) affine expression is applied to `laneId` and folded
// to a constant index where possible.
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  // Materializes `e` with d0 := laneId.
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}
680 
// Loads the lane's fragment elements and packs them into a vector of
// `vectorShape`. Element i (row-major) of the vector receives the i-th load,
// so the number of indexings must equal the vector's element count.
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  // Seed the vector with a splat of the first load, then overwrite every
  // element with vector.insert.
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}
702 
// Emits one memref.store into `memref` per (indexing, value) pair. The
// indexings from `indexFn` and `toStore` are zipped with zip_equal, so their
// lengths must match.
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  // Materializes `e` with d0 := laneId.
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}
720 
// Extracts every element of `vectorToStore` (row-major) with vector.extract and
// stores them through buildMemRefStores.
// NOTE(review): `vectorShape` appears unused here; the element order comes
// from the type of `vectorToStore`.
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  // Capacity hint only.
  toStore.reserve(32);
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}
738 
// Copies the (lhs, rhs, res) vector fragment shapes into owning SmallVectors
// and returns them as a tuple, in that order.
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs{lhs};
  SmallVector<int64_t> vrhs{rhs};
  SmallVector<int64_t> vres{res};
  return std::make_tuple(vlhs, vrhs, vres);
}
748 
749 FailureOr<MmaSyncBuilder::MmaSyncInfo>
750 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
751  TypeRange elementalTypes) {
752  // TODO: Tablegen all this.
753  Type f16 = b.getF16Type();
754  Type f32 = b.getF32Type();
755  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
756  elementalTypes == TypeRange{f32, f32, f32}) {
757  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
758  &MmaSyncBuilder::m16n8k4tf32Rhs,
759  &MmaSyncBuilder::m16n8k4tf32Res),
760  makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
761  SmallVector<int64_t>{opShape},
762  /*tf32Enabled=*/true};
763  }
764  // This is the version with f16 accumulation.
765  // TODO: version with f32 accumulation.
766  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
767  elementalTypes == TypeRange{f16, f16, f16}) {
768  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
769  &MmaSyncBuilder::m16n8k16f16Rhs,
770  &MmaSyncBuilder::m16n8k16f16Res),
771  makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
772  SmallVector<int64_t>{opShape},
773  /*tf32Enabled=*/false};
774  }
775  return failure();
776 }
777 
778 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
779  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
780  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
781  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
782  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
783  "expected lhs to be a 2D memref");
784  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
785  "expected rhs to be a 2D memref");
786  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
787  "expected res to be a 2D memref");
788 
789  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
790  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
791  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
792  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
793  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
794  Type resType = getElementTypeOrSelf(resMemRef.getType());
795 
796  FailureOr<MmaSyncInfo> maybeInfo =
797  getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
798  if (failed(maybeInfo))
799  return failure();
800 
801  MmaSyncInfo info = *maybeInfo;
802  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
803  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
804  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
805  lhsIndexFn, lhsShape);
806  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
807  rhsIndexFn, rhsShape);
808  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
809  resIndexFn, resShape);
810  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
811  info.tf32Enabled);
812  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
813  resShape);
814  return res.getDefiningOp();
815 }
816 
/// Rewrites a plain linalg.matmul into an nvgpu.mma_sync plus the supporting
/// per-lane loads and stores, then erases the matmul.
///
/// Matmuls with user-defined indexing maps are rejected with a silenceable
/// error. Any other op, or a matmul that MmaSyncBuilder cannot handle, yields
/// a silenceable "unsupported target op" error with a note on the target.
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Only matmuls with the default (non-extended) semantics are handled;
    // reject ones with user-defined indexing maps.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    // NOTE(review): if buildMmaSync fails below, this gpu.thread_id op is left
    // in the IR.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
}
848 
849 //===----------------------------------------------------------------------===//
850 // Hopper builders.
851 //===----------------------------------------------------------------------===//
852 
853 /// Helper to create the base Hopper-specific operations that are reused in
854 /// various other places.
857  : rewriter(rewriter), loc(loc) {}
858 
860  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
861 
862  /// Create tma descriptor op to initiate transfer from global to shared
863  /// memory. This must be done before the launch op, on the host.
865  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
866  gpu::LaunchOp launchOp);
867 
868  /// Build a tma load from global memory to shared memory using `barrier` to
869  /// synchronize. Return the number of bytes that will be transferred.
871  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
872  TypedValue<MemRefType> sharedMemref,
875  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
876  ArrayRef<OpFoldResult> sizes);
877 
878  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
879  /// Return the operation that performs the transfer on thread0.
880  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
881  SmallVector<Operation *> buildPredicateLoadsOnThread0(
883  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
885 
886  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
887 
890 };
891 
893  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
894  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
896  SmallVector<Operation *> loadOps;
897  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
898  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
899  Value cond =
900  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
901  // clang-format off
902  rewriter.create<scf::IfOp>(
903  /*location=*/loc,
904  /*conditional=*/cond,
905  /*thenBuilder=*/
906  [&](OpBuilder &lb, Location loc) {
908  sizes.reserve(globalDescriptors.size());
909  for (auto [desc, shmem] : llvm::zip_equal(
910  globalDescriptors, sharedMemBuffers)) {
911  OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
912  sizes.push_back(sz);
913  }
914  // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
915  // This may or may not have perf implications.
916  buildBarrierArriveTx(barrier, sizes);
917  rewriter.create<scf::YieldOp>(loc);
918  },
919  /*elseBuilder=*/
920  [&](OpBuilder &lb, Location loc) {
921  // TODO: is this for no-thread divergence?
922  // Should we just yield the size and hoist?
923  buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
924  rewriter.create<scf::YieldOp>(loc);
925  });
926  // clang-format on
927  return loadOps;
928 }
929 
932  b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
933  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
934 }
935 
938  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
939  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
940  loc,
941  nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
942  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
943  rewriter.create<nvgpu::MBarrierInitOp>(
944  loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
945  zero, Value());
946  rewriter.create<gpu::BarrierOp>(loc);
947  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
948 }
949 
952  gpu::LaunchOp launchOp) {
953  OpBuilder::InsertionGuard guard(rewriter);
954  rewriter.setInsertionPoint(launchOp);
955  Value unrankedMemRef = rewriter.create<memref::CastOp>(
956  loc,
957  UnrankedMemRefType::get(memref.getType().getElementType(),
958  memref.getType().getMemorySpace()),
959  memref);
960  SmallVector<OpFoldResult> mixedSizes =
961  memref::getMixedSizes(rewriter, loc, memref);
962  SmallVector<Value> sizes =
963  getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
964 
965  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
966  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
967  loc,
969  rewriter.getContext(),
970  MemRefType::Builder(memref.getType())
971  .setMemorySpace(sharedMemorySpace),
972  TensorMapSwizzleKind::SWIZZLE_NONE,
973  TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
974  TensorMapInterleaveKind::INTERLEAVE_NONE),
975  unrankedMemRef, sizes);
976  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
977 }
978 
981  TypedValue<MemRefType> sharedMemref,
983  SmallVectorImpl<Operation *> &loadOps) {
984  MLIRContext *ctx = rewriter.getContext();
985  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
986  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
987  loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
988  Value(), Value());
989  loadOps.push_back(loadOp);
990  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
991  SmallVector<AffineExpr> symbols(mixedSizes.size());
992  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
993  AffineExpr prodExprInBytes =
994  computeProduct(ctx, symbols) *
995  (sharedMemref.getType().getElementTypeBitWidth() / 8);
996  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
997  prodExprInBytes, mixedSizes);
998  return res;
999 }
1000 
1003  ArrayRef<OpFoldResult> mixedSizes) {
1004  assert(!mixedSizes.empty() && "expecte non-empty sizes");
1005  MLIRContext *ctx = rewriter.getContext();
1006  SmallVector<AffineExpr> symbols(mixedSizes.size());
1007  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1008  AffineExpr sumExpr = computeSum(ctx, symbols);
1009  OpFoldResult size =
1010  affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1011  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1012  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1013  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1014  Value());
1015 }
1016 
1019  Type i1 = rewriter.getI1Type();
1020  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
1021  // 10M is an arbitrary, not too small or too big number to specify the number
1022  // of ticks before retry.
1023  // TODO: hoist this in a default dialect constant.
1024  Value ticksBeforeRetry =
1025  rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
1026  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1027  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1028  ticksBeforeRetry, zero);
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // RewriteCopyAsTmaOp
1033 //===----------------------------------------------------------------------===//
1034 
1035 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1036 struct CopyBuilder : public HopperBuilder {
  // Constructor: forwards `rewriter` and `loc` unchanged to the HopperBuilder base.
1038  : HopperBuilder(rewriter, loc) {}
1039 
1041 };
1042 
/// Rewrites each op in `copyOps` into a TMA-based transfer:
///   1. create and initialize a barrier whose thread count is the product of
///      the enclosing `gpu.launch` block sizes (x * y * z);
///   2. build a TMA descriptor for each copy's 2-D input, inserted before the
///      launch op;
///   3. record each copy's init operand as the destination buffer;
///   4. issue the TMA loads predicated on thread 0;
///   5. wait on the barrier parity;
///   6. erase the original copy ops.
/// Returns the load ops created in step 4, or an empty vector when `copyOps`
/// is empty. Every op must be a `linalg.copy` (enforced via `cast`). The
/// `gpu.launch` is taken from the first op only, so callers must guarantee
/// that all ops share it.
1043 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1044  MLIRContext *ctx = rewriter.getContext();
1045  if (copyOps.empty())
1046  return SmallVector<Operation *>();
1047 
1048  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1049  assert(launchOp && "expected launch op");
1050 
1051  // 1. Init a barrier object in shared memory.
1052  OpBuilder::InsertionGuard g(rewriter);
1053  rewriter.setInsertionPoint(copyOps.front());
  // Thread count = blockSizeX * blockSizeY * blockSizeZ of the launch.
1054  AffineExpr bx, by, bz;
1055  bindSymbols(ctx, bx, by, bz);
1056  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1058  rewriter, loc, prod,
1059  ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1060  launchOp.getBlockSizeZ()});
1061 
1063  buildAndInitBarrierInSharedMemory(numThreads);
1064 
1067  for (Operation *op : copyOps) {
1068  auto copyOp = cast<linalg::CopyOp>(op);
1069  auto inMemRef =
1070  cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1071  assert(inMemRef.getType().getRank() == 2 &&
1072  "expected in to be a 2D memref");
1073 
1074  // 2. Build global memory descriptor.
1076  buildGlobalMemRefDescriptor(inMemRef, launchOp);
1077  globalDescs.push_back(globalDesc);
1078 
1079  // 3. Record the copy's init operand as the TMA destination buffer (no descriptor is built for it).
1080  auto shmem =
1081  cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1082  shmems.push_back(shmem);
1083  }
1084 
1085  // 4. Load in from global memory to shared memory using tma.
1086  OpBuilder::InsertionGuard g2(rewriter);
1087  rewriter.setInsertionPoint(copyOps.front());
1088  SmallVector<Operation *> results =
1089  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1090 
1091  // 5. Spin-loop until data is ready.
1092  buildTryWaitParity(barrier);
1093 
1094  // 6. Erase the ops that have now been rewritten.
1095  for (Operation *op : copyOps)
1096  rewriter.eraseOp(op);
1097 
1098  return results;
1099 }
1100 
1102 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1103  transform::TransformResults &results,
1104  transform::TransformState &state) {
1105  auto payloadOps = state.getPayloadOps(getTarget());
1106  gpu::LaunchOp commonLaunchOp;
1107  Operation *firstOp, *failingOp;
1108  if (llvm::any_of(payloadOps, [&](Operation *op) {
1109  if (!commonLaunchOp) {
1110  commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1111  firstOp = op;
1112  }
1113  auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1114  commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1115  !isa<linalg::CopyOp>(op);
1116  if (fail)
1117  failingOp = op;
1118  return fail;
1119  })) {
1121  emitSilenceableError()
1122  << "target ops must be linalg::CopyOp nested under a common "
1123  "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1124  "be created on the host.\nBut got: "
1125  << *firstOp << "\nand " << *failingOp;
1126  return diag;
1127  }
1128 
1129  // TODO: more robust detection of copy, with transposes etc.
1130  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1131 
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // Transform op registration
1137 //===----------------------------------------------------------------------===//
1138 
1139 namespace {
/// Transform dialect extension that registers the NVGPU transform ops listed in
/// the generated NVGPUTransformOps.cpp.inc and declares the dialects whose ops
/// those transforms may generate.
// NOTE(review): transforms in this file also create gpu, scf, memref and LLVM
// dialect ops (e.g. gpu::ThreadIdOp, scf::IfOp, memref::CastOp,
// LLVM::ConstantOp), and those dialects are not declared below; confirm they
// are guaranteed to be loaded by other means.
1140 class NVGPUTransformDialectExtension
1142  NVGPUTransformDialectExtension> {
1143 public:
1144  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1145 
1146  NVGPUTransformDialectExtension() {
  // Dialects whose ops the registered transforms may generate.
1147  declareGeneratedDialect<arith::ArithDialect>();
1148  declareGeneratedDialect<affine::AffineDialect>();
1149  declareGeneratedDialect<nvgpu::NVGPUDialect>();
1150  declareGeneratedDialect<NVVM::NVVMDialect>();
1151  declareGeneratedDialect<vector::VectorDialect>();
  // Register every op from the TableGen-generated op list.
1152  registerTransformOps<
1153 #define GET_OP_LIST
1154 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1155  >();
1156  }
1157 };
1158 } // namespace
1159 
1160 #define GET_OP_CLASSES
1161 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1162 
1164  registry.addExtensions<NVGPUTransformDialectExtension>();
1165 }
static constexpr int64_t kSharedMemorySpace
static std::string diag(const llvm::Value &value)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
Definition: NVGPUDialect.h:28
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b)
static bool hasDefaultMemorySpace(BaseMemRefType type)
Returns true if the given type has the default memory space.
static LogicalResult collectStage0PipeliningOps(scf::ForOp forOp, llvm::SmallPtrSet< Operation *, 16 > &ops)
Populate ops with the set of operations that belong to the stage 0 of the pipelined version of the gi...
static Operation * replaceOpWithPredicatedOp(RewriterBase &rewriter, Operation *op, Value predicate)
Hook for the loop pipeliner.
static bool isStoreToShared(Operation *op, Value v)
Returns true if the operation is storing the given value into shared memory.
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn)
Helper functions to create customizable load and stores operations.
static std::tuple< DiagnosedSilenceableFailure, scf::ForOp > pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, bool epiloguePeeling)
Applies loop pipelining with the given depth to the given loop so that copies into the shared memory ...
static bool hasSharedMemorySpace(BaseMemRefType type)
Returns true if the given type has the shared (workgroup) memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op)
Returns true if the operation is a load from the default memory space the result of which is only sto...
static std::tuple< SmallVector< int64_t >, SmallVector< int64_t >, SmallVector< int64_t > > makeVectorShapes(ArrayRef< int64_t > lhs, ArrayRef< int64_t > rhs, ArrayRef< int64_t > res)
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, scf::PipeliningOption::PipelinerPart part, unsigned iteration, unsigned depth)
Hook for the loop pipeliner that sets the "num groups in flight" attribute of async wait operations c...
static void getPipelineStages(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &opsWithPipelineStages, unsigned depth, llvm::SmallPtrSetImpl< Operation * > &stage0Ops)
Hook for the loop pipeliner that populates ops with the stage information as follows:
static Value getValueLoadedFromGlobal(Operation *op)
Returns the value produced by a load from the default memory space.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:917
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
FloatType getF32Type()
Definition: Builders.cpp:87
FloatType getF16Type()
Definition: Builders.cpp:83
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:213
Builder & setMemorySpace(Attribute newMemorySpace)
Definition: BuiltinTypes.h:239
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:59
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
bool isF16() const
Definition: Types.cpp:57
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
@ kGlobalMemorySpace
Global memory space identifier.
Definition: NVVMDialect.h:36
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
void registerTransformDialectExtension(DialectRegistry &registry)
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
int64_t computeSum(ArrayRef< int64_t > basis)
Self-explicit.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Definition: AffineExpr.h:367
Helper to create the tma operations corresponding to linalg::CopyOp.
SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)
CopyBuilder(RewriterBase &rewriter, Location loc)
Helper to create the base Hopper-specific operations that are reused in various other places.
OpFoldResult buildTmaAsyncLoad(TypedValue< nvgpu::TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< nvgpu::MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)
Build a tma load from global memory to shared memory using barrier to synchronize.
TypedValue< nvgpu::MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)
void buildTryWaitParity(TypedValue< nvgpu::MBarrierGroupType > barrier)
RewriterBase & rewriter
TypedValue< nvgpu::TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)
Create tma descriptor op to initiate transfer from global to shared memory.
SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< nvgpu::TensorMapDescriptorType >> globalDescriptors, ArrayRef< TypedValue< MemRefType >> sharedMemBuffers, TypedValue< nvgpu::MBarrierGroupType > barrier)
If threadIdx.x == 0 does TMA request + wait, else just wait.
void buildBarrierArriveTx(TypedValue< nvgpu::MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)
HopperBuilder(RewriterBase &rewriter, Location loc)
Helper struct to provide a simple mapping from matmul operations to the corresponding mma....
std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)
Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...
Helper struct to encode a pair of row/column indexings in the form of affine expressions.
AffineExpr col() const
RowColIndexing(AffineExpr row, AffineExpr col)
void print(llvm::raw_ostream &os) const
AffineExpr row() const
Options to dictate how loops should be pipelined.
Definition: Transforms.h:123