MLIR
20.0.0git
lib
Dialect
NVGPU
Transforms
MmaSyncTF32Transform.cpp
Go to the documentation of this file.
1
//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
10
// operations on f32 input datatype
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "
mlir/Dialect/NVGPU/Transforms/Transforms.h
"
15
16
#include "
mlir/Dialect/Arith/IR/Arith.h
"
17
#include "
mlir/Dialect/GPU/IR/GPUDialect.h
"
18
#include "
mlir/Dialect/MemRef/IR/MemRef.h
"
19
#include "
mlir/Dialect/NVGPU/IR/NVGPUDialect.h
"
20
#include "
mlir/Dialect/Vector/IR/VectorOps.h
"
21
#include "
mlir/Interfaces/SideEffectInterfaces.h
"
22
#include "llvm/ADT/STLExtras.h"
23
#include "llvm/Support/MathExtras.h"
24
25
using namespace
mlir
;
26
using namespace
mlir::nvgpu
;
27
28
namespace
{
29
30
struct
MmaSyncF32ToTF32Pattern :
public
OpRewritePattern
<nvgpu::MmaSyncOp> {
31
32
using
OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern
;
33
34
MmaSyncF32ToTF32Pattern(
MLIRContext
*context,
35
nvgpu::MmaSyncF32Lowering
precision)
36
:
OpRewritePattern
<nvgpu::MmaSyncOp>(context,
/*benifit*/
1),
37
precision(precision) {}
38
39
LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
40
PatternRewriter
&rewriter)
const override
{
41
Location
location = op->getLoc();
42
43
if
(op->hasAttr(op.getTf32EnabledAttrName()) ||
44
!cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
45
return
failure();
46
47
if
(precision ==
MmaSyncF32Lowering::Unkown
)
48
return
emitError
(location,
"MmaSync F32-to-TF32 cannot be lowered with "
49
"unknown precision level"
);
50
51
if
(precision ==
MmaSyncF32Lowering::TF32x3
)
52
return
emitError
(location,
"TF32x3 is not supported at the moment "
53
"for nvgpu.mma.sync on f32 datatype"
);
54
55
if
(precision ==
MmaSyncF32Lowering::TF32
) {
56
rewriter.
modifyOpInPlace
(
57
op, [&]() { op.setTf32EnabledAttr(rewriter.
getUnitAttr
()); });
58
}
59
60
return
success();
61
}
62
63
private
:
64
/// Precision for F32 Tensor Cores (TF32 or TF32x3)
65
nvgpu::MmaSyncF32Lowering
precision;
66
};
67
68
}
// namespace
69
70
void
mlir::nvgpu::populateMmaSyncF32ToTF32Patterns
(
71
RewritePatternSet
&
patterns
,
nvgpu::MmaSyncF32Lowering
precision) {
72
73
patterns
.add<MmaSyncF32ToTF32Pattern>(
patterns
.getContext(), precision);
74
}
GPUDialect.h
NVGPUDialect.h
SideEffectInterfaces.h
VectorOps.h
mlir::Builder::getUnitAttr
UnitAttr getUnitAttr()
Definition:
Builders.cpp:138
mlir::Location
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition:
Location.h:66
mlir::MLIRContext
MLIRContext is the top-level object for a collection of MLIR operations.
Definition:
MLIRContext.h:60
mlir::PatternRewriter
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition:
PatternMatch.h:791
mlir::RewritePatternSet
Definition:
PatternMatch.h:814
mlir::RewriterBase::modifyOpInPlace
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition:
PatternMatch.h:636
Arith.h
MemRef.h
Transforms.h
mlir::nvgpu
Definition:
NVGPUToNVVM.h:25
mlir::nvgpu::MmaSyncF32Lowering
MmaSyncF32Lowering
Rewrites patterns.
Definition:
Transforms.h:57
mlir::nvgpu::MmaSyncF32Lowering::Unkown
@ Unkown
mlir::nvgpu::MmaSyncF32Lowering::TF32
@ TF32
mlir::nvgpu::MmaSyncF32Lowering::TF32x3
@ TF32x3
mlir::nvgpu::populateMmaSyncF32ToTF32Patterns
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided ...
Definition:
MmaSyncTF32Transform.cpp:70
mlir
Include the generated interface declarations.
Definition:
LocalAliasAnalysis.h:20
mlir::emitError
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Definition:
Diagnostics.cpp:328
mlir::patterns
const FrozenRewritePatternSet & patterns
Definition:
GreedyPatternRewriteDriver.h:233
mlir::OpRewritePattern
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition:
PatternMatch.h:358
Generated on Sat Dec 21 2024 12:33:58 for MLIR by
1.9.1