MLIR  19.0.0git
MmaSyncTF32Transform.cpp
Go to the documentation of this file.
1 //===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements transforms that enable 1xTF32 and 3xTF32 nvgpu.mma.sync
10 // operations on f32 inputs. Only 1xTF32 is currently supported.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/MathExtras.h"
25 
26 using namespace mlir;
27 using namespace mlir::nvgpu;
28 
29 namespace {
30 
31 struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
32 
34 
35  MmaSyncF32ToTF32Pattern(MLIRContext *context,
36  nvgpu::MmaSyncF32Lowering precision)
37  : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
38  precision(precision) {}
39 
40  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
41  PatternRewriter &rewriter) const override {
42  Location location = op->getLoc();
43 
44  if (op->hasAttr(op.getTf32EnabledAttrName()) ||
45  !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
46  return failure();
47 
48  if (precision == MmaSyncF32Lowering::Unkown)
49  return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
50  "unknown precision level");
51 
52  if (precision == MmaSyncF32Lowering::TF32x3)
53  return emitError(location, "TF32x3 is not supported at the moment "
54  "for nvgpu.mma.sync on f32 datatype");
55 
56  if (precision == MmaSyncF32Lowering::TF32) {
57  rewriter.modifyOpInPlace(
58  op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
59  }
60 
61  return success();
62  }
63 
64 private:
65  /// Precision for F32 Tensor Cores (TF32 or TF32x3)
66  nvgpu::MmaSyncF32Lowering precision;
67 };
68 
69 } // namespace
70 
72  RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
73 
74  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
75 }
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
MmaSyncF32Lowering
Rewrites patterns.
Definition: Transforms.h:58
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358