MLIR
22.0.0git
Main Page
Related Pages
Namespaces
Namespace List
Namespace Members
All
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
y
Functions
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
y
Variables
a
c
f
h
i
k
m
n
o
p
r
s
Typedefs
a
b
c
d
e
f
g
h
i
l
m
n
o
p
q
r
s
t
u
v
w
Enumerations
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
v
w
Enumerator
a
b
c
d
e
f
g
h
i
k
m
n
o
p
r
s
t
u
v
w
Classes
Class List
Class Index
Class Hierarchy
Class Members
All
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
~
Functions
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
y
~
Variables
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
Typedefs
a
b
c
d
e
f
g
h
i
k
l
m
n
o
p
r
s
t
u
v
w
Enumerations
a
b
c
d
f
i
k
l
m
n
o
p
r
s
t
u
v
w
Enumerator
a
c
d
e
f
g
h
i
k
l
m
n
p
r
s
u
v
Related Functions
a
b
c
d
e
f
g
h
i
l
m
n
o
p
r
s
t
v
Files
File List
File Members
All
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
Functions
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
Variables
_
a
b
c
d
e
g
h
i
k
l
m
n
o
p
r
s
t
u
v
w
x
Typedefs
a
b
c
d
e
f
g
h
i
m
n
o
r
s
t
u
v
y
Enumerations
Enumerator
a
b
c
e
f
g
i
m
n
s
t
w
Macros
_
a
b
c
d
e
f
g
h
i
l
m
n
o
p
r
s
t
u
v
w
y
z
lib
Dialect
NVGPU
Transforms
MmaSyncTF32Transform.cpp
Go to the documentation of this file.
1
//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
10
// operations on f32 input datatype
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "
mlir/Dialect/NVGPU/Transforms/Transforms.h
"
15
16
#include "
mlir/Dialect/MemRef/IR/MemRef.h
"
17
#include "
mlir/Dialect/NVGPU/IR/NVGPUDialect.h
"
18
#include "
mlir/Dialect/Vector/IR/VectorOps.h
"
19
20
using namespace
mlir
;
21
using namespace
mlir::nvgpu
;
22
23
namespace
{
24
25
/// Rewrite pattern over `nvgpu::MmaSyncOp` that marks f32 mma.sync operations
/// as TF32-enabled, according to the precision level supplied at construction.
struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {

  using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;

  /// Builds the pattern with benefit 1 and remembers the requested
  /// `precision` for use in `matchAndRewrite`.
  MmaSyncF32ToTF32Pattern(MLIRContext *context,
                          nvgpu::MmaSyncF32Lowering precision)
      : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benefit*/ 1),
        precision(precision) {}

  /// Fails to match when `op` already carries the tf32-enabled attribute or
  /// when the element type of matrix A is not f32. Otherwise:
  ///   - `Unkown` and `TF32x3` emit an error at the op location and fail;
  ///   - `TF32` sets the tf32-enabled unit attribute on `op` in place.
  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    // Nothing to do if the op was already rewritten.
    if (op->hasAttr(op.getTf32EnabledAttrName()))
      return failure();

    // Only f32 operands are handled by this pattern.
    auto matrixAType = cast<VectorType>(op.getMatrixA().getType());
    if (!matrixAType.getElementType().isF32())
      return failure();

    switch (precision) {
    case MmaSyncF32Lowering::Unkown:
      return emitError(loc, "MmaSync F32-to-TF32 cannot be lowered with "
                            "unknown precision level");
    case MmaSyncF32Lowering::TF32x3:
      return emitError(loc, "TF32x3 is not supported at the moment "
                            "for nvgpu.mma.sync on f32 datatype");
    case MmaSyncF32Lowering::TF32:
      // Route the attribute change through the rewriter so that it is
      // notified of the in-place modification.
      rewriter.modifyOpInPlace(
          op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
      break;
    }

    return success();
  }

private:
  /// Precision for F32 Tensor Cores (TF32 or TF32x3)
  nvgpu::MmaSyncF32Lowering precision;
};
62
63
}
// namespace
64
65
void
mlir::nvgpu::populateMmaSyncF32ToTF32Patterns
(
66
RewritePatternSet
&
patterns
,
nvgpu::MmaSyncF32Lowering
precision) {
67
68
patterns
.add<MmaSyncF32ToTF32Pattern>(
patterns
.getContext(), precision);
69
}
NVGPUDialect.h
VectorOps.h
mlir::Builder::getUnitAttr
UnitAttr getUnitAttr()
Definition:
Builders.cpp:93
mlir::Location
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition:
Location.h:76
mlir::MLIRContext
MLIRContext is the top-level object for a collection of MLIR operations.
Definition:
MLIRContext.h:60
mlir::PatternRewriter
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition:
PatternMatch.h:769
mlir::RewritePatternSet
Definition:
PatternMatch.h:792
mlir::RewriterBase::modifyOpInPlace
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition:
PatternMatch.h:614
MemRef.h
Transforms.h
mlir::nvgpu
Definition:
NVGPUToNVVM.h:25
mlir::nvgpu::MmaSyncF32Lowering
MmaSyncF32Lowering
Rewrites patterns.
Definition:
Transforms.h:57
mlir::nvgpu::MmaSyncF32Lowering::Unkown
@ Unkown
mlir::nvgpu::MmaSyncF32Lowering::TF32
@ TF32
mlir::nvgpu::MmaSyncF32Lowering::TF32x3
@ TF32x3
mlir::nvgpu::populateMmaSyncF32ToTF32Patterns
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided ...
Definition:
MmaSyncTF32Transform.cpp:65
mlir
Include the generated interface declarations.
Definition:
LocalAliasAnalysis.h:20
mlir::emitError
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Definition:
Diagnostics.cpp:328
mlir::patterns
const FrozenRewritePatternSet & patterns
Definition:
GreedyPatternRewriteDriver.h:283
mlir::OpRewritePattern
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition:
PatternMatch.h:314
Generated on Fri Jul 25 2025 04:35:50 for MLIR by
1.9.1