MLIR 22.0.0git
Helper to create the TMA operations corresponding to linalg::CopyOp.
Public Member Functions
| Return type | Member | Description |
|---|---|---|
| | CopyBuilder (RewriterBase &rewriter, Location loc) | |
| SmallVector< Operation * > | rewrite (ArrayRef< Operation * > copyOps) | |

Public Member Functions inherited from HopperBuilder
| Return type | Member | Description |
|---|---|---|
| | HopperBuilder (RewriterBase &rewriter, Location loc) | |
| TypedValue< MBarrierGroupType > | buildAndInitBarrierInSharedMemory (OpFoldResult numThreads) | |
| TypedValue< TensorMapDescriptorType > | buildGlobalMemRefDescriptor (TypedValue< MemRefType > memref, gpu::LaunchOp launchOp) | Create a TMA descriptor op to initiate a transfer from global to shared memory. |
| OpFoldResult | buildTmaAsyncLoad (TypedValue< TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps) | Build a TMA load from global memory to shared memory, using the barrier to synchronize. |
| void | buildBarrierArriveTx (TypedValue< MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes) | |
| SmallVector< Operation * > | buildPredicateLoadsOnThread0 (ArrayRef< TypedValue< TensorMapDescriptorType > > globalDescriptors, ArrayRef< TypedValue< MemRefType > > sharedMemBuffers, TypedValue< MBarrierGroupType > barrier) | If threadIdx.x == 0, issue the TMA requests and wait; otherwise, just wait. |
| void | buildTryWaitParity (TypedValue< MBarrierGroupType > barrier) | |

Additional Inherited Members

Public Attributes inherited from HopperBuilder
| Type | Name |
|---|---|
| RewriterBase & | rewriter |
| Location | loc |
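buildTmaAsyncLoad returns an OpFoldResult, and buildBarrierArriveTx takes a list of OpFoldResult sizes. This suggests that each load reports how many bytes it will deliver and the barrier is then told the total. The standalone C++ sketch below models that accounting under this assumption. The names SharedBuffer, tmaLoadBytes, and expectedTxBytes are hypothetical. Real code would build these values as folded affine expressions over possibly dynamic sizes rather than as plain integers.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical description of a statically shaped shared-memory destination.
struct SharedBuffer {
  std::vector<int64_t> shape;  // e.g. {64, 128}
  int64_t elementBitWidth;     // e.g. 16 for f16
};

// Assumed size of one TMA load: number of elements times bytes per element.
int64_t tmaLoadBytes(const SharedBuffer &buf) {
  int64_t numElements =
      std::accumulate(buf.shape.begin(), buf.shape.end(), int64_t{1},
                      std::multiplies<int64_t>());
  return numElements * (buf.elementBitWidth / 8);
}

// Assumed transaction count announced to the barrier: the sum over all loads.
int64_t expectedTxBytes(const std::vector<SharedBuffer> &bufs) {
  int64_t total = 0;
  for (const SharedBuffer &buf : bufs)
    total += tmaLoadBytes(buf);
  return total;
}
```

For a 64x128 f16 buffer the sketch gives 64 * 128 * 2 = 16384 bytes. Adding a 128x64 f32 buffer (32768 bytes) brings the expected total to 49152 bytes.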
CopyBuilder is defined at line 1027 of file NVGPUTransformOps.cpp.
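The generated TMA operations are synchronized through a shared-memory barrier. buildAndInitBarrierInSharedMemory presumably initializes it for numThreads arrivals, buildBarrierArriveTx tells it how many transfer bytes to expect, and buildTryWaitParity polls it. The single-threaded C++ model below is a simplified sketch of those phase semantics. It is not the NVGPU dialect or PTX API, and MBarrierModel and its methods are hypothetical. The model makes two assumptions. First, a phase completes once every expected thread has arrived and every announced byte has landed. Second, threads other than threadIdx.x == 0 arrive without announcing any bytes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical single-threaded model of one mbarrier's phase tracking.
// Assumption: a phase completes when all expected arrivals have happened
// AND all announced transaction bytes have landed; completion flips the
// phase bit and re-arms the arrival count for the next phase.
class MBarrierModel {
 public:
  explicit MBarrierModel(int64_t numThreads)
      : expectedArrivals_(numThreads), pendingArrivals_(numThreads) {}

  // A thread arrives and announces `txBytes` more bytes still to land.
  void arriveExpectTx(int64_t txBytes) {
    pendingTxBytes_ += txBytes;
    arrive();
  }

  // A plain arrival, equivalent to arriveExpectTx(0).
  void arrive() {
    --pendingArrivals_;
    maybeCompletePhase();
  }

  // The copy engine reports `bytes` delivered to shared memory.
  void completeTx(int64_t bytes) {
    pendingTxBytes_ -= bytes;
    maybeCompletePhase();
  }

  // True once the phase with the given parity (0 or 1) has completed.
  bool tryWaitParity(unsigned parity) const { return phaseBit_ != parity; }

 private:
  void maybeCompletePhase() {
    if (pendingArrivals_ == 0 && pendingTxBytes_ == 0) {
      phaseBit_ ^= 1u;
      pendingArrivals_ = expectedArrivals_;
    }
  }

  int64_t expectedArrivals_;
  int64_t pendingArrivals_;
  int64_t pendingTxBytes_ = 0;
  unsigned phaseBit_ = 0;
};
```

In the model, waiters on parity 0 keep spinning after all threads have arrived. They are released only once the last announced byte lands. This is the behavior a "TMA request + wait" pattern relies on.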
| CopyBuilder::CopyBuilder | ( | RewriterBase & | rewriter, | Location | loc | ) | inline |
Definition at line 1028 of file NVGPUTransformOps.cpp.
References HopperBuilder::HopperBuilder(), HopperBuilder::loc, and HopperBuilder::rewriter.
| SmallVector< Operation * > CopyBuilder::rewrite | ( | ArrayRef< Operation * > | copyOps | ) |
Definition at line 1034 of file NVGPUTransformOps.cpp.
References mlir::bindSymbols(), HopperBuilder::buildAndInitBarrierInSharedMemory(), HopperBuilder::buildGlobalMemRefDescriptor(), HopperBuilder::buildPredicateLoadsOnThread0(), HopperBuilder::buildTryWaitParity(), mlir::computeProduct(), mlir::get(), HopperBuilder::loc, mlir::affine::makeComposedFoldedAffineApply(), and HopperBuilder::rewriter.
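The References list names the helpers that rewrite relies on, but not the order in which it calls them. Data dependencies fix most of that order. The barrier must exist before loads can signal it, the descriptors must exist before loads can use them, and the wait must follow the loads. The C++ sketch below records that assumed order as a string trace. rewriteStepTrace is a hypothetical name. It is also an assumption that the bindSymbols, computeProduct, and makeComposedFoldedAffineApply references compute numThreads = blockDim.x * blockDim.y * blockDim.z.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical trace of the order in which CopyBuilder::rewrite is assumed
// to invoke the HopperBuilder helpers for `numCopies` linalg.copy ops.
std::vector<std::string> rewriteStepTrace(const std::array<int64_t, 3> &blockDims,
                                          std::size_t numCopies) {
  std::vector<std::string> steps;
  // Assumed barrier arrival count: the product of the launch block sizes.
  int64_t numThreads = blockDims[0] * blockDims[1] * blockDims[2];
  steps.push_back("buildAndInitBarrierInSharedMemory(" +
                  std::to_string(numThreads) + ")");
  // One global tensor-map descriptor per copy source.
  for (std::size_t i = 0; i < numCopies; ++i)
    steps.push_back("buildGlobalMemRefDescriptor(copy " + std::to_string(i) + ")");
  // All loads are issued together under a single threadIdx.x == 0 predicate.
  steps.push_back("buildPredicateLoadsOnThread0");
  // Every thread then waits on the barrier phase.
  steps.push_back("buildTryWaitParity");
  return steps;
}
```

Grouping every copy under one barrier and one predicate means a single wait covers all of the transfers, instead of one wait per copy.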