MLIR 22.0.0git
MmaSyncTF32Transform.cpp
Go to the documentation of this file.
1//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
10// operations on the f32 input datatype.
11//
12//===----------------------------------------------------------------------===//
13
15
19
20using namespace mlir;
21using namespace mlir::nvgpu;
22
23namespace {
24
25struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
26
27 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
28
29 MmaSyncF32ToTF32Pattern(MLIRContext *context,
31 : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
32 precision(precision) {}
33
34 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
35 PatternRewriter &rewriter) const override {
36 Location location = op->getLoc();
37
38 if (op->hasAttr(op.getTf32EnabledAttrName()) ||
39 !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
40 return failure();
41
42 if (precision == MmaSyncF32Lowering::Unkown)
43 return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
44 "unknown precision level");
45
46 if (precision == MmaSyncF32Lowering::TF32x3)
47 return emitError(location, "TF32x3 is not supported at the moment "
48 "for nvgpu.mma.sync on f32 datatype");
49
50 if (precision == MmaSyncF32Lowering::TF32) {
51 rewriter.modifyOpInPlace(
52 op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
53 }
54
55 return success();
56 }
57
58private:
59 /// Precision for F32 Tensor Cores (TF32 or TF32x3)
61};
62
63} // namespace
64
67 patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
68}
return success()
UnitAttr getUnitAttr()
Definition Builders.cpp:98
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
MmaSyncF32Lowering
Rewrites patterns.
Definition Transforms.h:57
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided ...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...