27 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
29 MmaSyncF32ToTF32Pattern(MLIRContext *context,
31 : OpRewritePattern<nvgpu::MmaSyncOp>(context, 1),
32 precision(precision) {}
34 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
35 PatternRewriter &rewriter)
const override {
36 Location location = op->getLoc();
38 if (op->hasAttr(op.getTf32EnabledAttrName()) ||
39 !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
42 if (precision == MmaSyncF32Lowering::Unkown)
43 return emitError(location,
"MmaSync F32-to-TF32 cannot be lowered with "
44 "unknown precision level");
46 if (precision == MmaSyncF32Lowering::TF32x3)
47 return emitError(location,
"TF32x3 is not supported at the moment "
48 "for nvgpu.mma.sync on f32 datatype");
50 if (precision == MmaSyncF32Lowering::TF32) {
52 op, [&]() { op.setTf32EnabledAttr(rewriter.
getUnitAttr()); });
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
MmaSyncF32Lowering
Rewrites patterns.
void populateMmaSyncF32ToTF32Patterns(RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision=nvgpu::MmaSyncF32Lowering::TF32)
Collect patterns to convert mma.sync on f32 input and rewrite to use tensor cores with user provided level of accuracy.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.