MLIR  20.0.0git
MmaSyncTF32Transform.cpp
Go to the documentation of this file.
1 //===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma.sync
10 // operations on the f32 input data type.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/MathExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::nvgpu;
27 
28 namespace {
29 
30 struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
31 
33 
34  MmaSyncF32ToTF32Pattern(MLIRContext *context,
35  nvgpu::MmaSyncF32Lowering precision)
36  : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
37  precision(precision) {}
38 
39  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
40  PatternRewriter &rewriter) const override {
41  Location location = op->getLoc();
42 
43  if (op->hasAttr(op.getTf32EnabledAttrName()) ||
44  !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
45  return failure();
46 
47  if (precision == MmaSyncF32Lowering::Unkown)
48  return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
49  "unknown precision level");
50 
51  if (precision == MmaSyncF32Lowering::TF32x3)
52  return emitError(location, "TF32x3 is not supported at the moment "
53  "for nvgpu.mma.sync on f32 datatype");
54 
55  if (precision == MmaSyncF32Lowering::TF32) {
56  rewriter.modifyOpInPlace(
57  op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
58  }
59 
60  return success();
61  }
62 
63 private:
64  /// Precision for F32 Tensor Cores (TF32 or TF32x3)
65  nvgpu::MmaSyncF32Lowering precision;
66 };
67 
68 } // namespace
69 
71  RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
72 
73  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
74 }
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
MmaSyncF32Lowering
Rewrites patterns.
Definition: Transforms.h:57
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided ...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358