MLIR 22.0.0git
TosaArithConstantToConst.cpp
Go to the documentation of this file.
//===- TosaArithConstantToConst.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that converts tensor-valued arith.constant ops
// into tosa.const so that TOSA pipelines operate on a uniform constant form.
//
//===----------------------------------------------------------------------===//
13
15
24
25namespace mlir {
26namespace tosa {
27#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS
28#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
29} // namespace tosa
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::tosa;
35namespace {
37// NOTE: TOSA pipelines already lower their constants through shared Arith
38// folding passes, so tensor literals often come back as `arith.constant` even
39// after the IR is otherwise TOSA-only. Keep this normalization with the rest of
40// the TOSA transforms so any client can re-establish a canonical `tosa.const`
41// representation without needing a full Arith->TOSA conversion library.
43/// Returns true when `elementType` is natively representable by tosa.const.
44static bool isSupportedElementType(Type elementType) {
45 if (isa<FloatType>(elementType))
46 return true;
47
48 if (auto intType = dyn_cast<IntegerType>(elementType))
49 return intType.isSignless() || intType.isUnsigned();
50
51 if (isa<quant::QuantizedType>(elementType))
52 return true;
53
54 if (isa<tosa::mxint8Type>(elementType))
55 return true;
56
57 return false;
58}
59
60class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> {
61public:
63
64 LogicalResult matchAndRewrite(arith::ConstantOp constOp,
65 PatternRewriter &rewriter) const override {
66 // TOSA constant verification requires a ranked, statically shaped tensor.
67 auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType());
68 if (!resultType || !resultType.hasStaticShape())
69 return failure();
71 if (!isSupportedElementType(resultType.getElementType()))
72 return failure();
73
74 Attribute attr = constOp.getValueAttr();
75 auto elementsAttr = dyn_cast<ElementsAttr>(attr);
76 if (!elementsAttr)
77 return failure();
78
79 auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType());
80 if (!attrType || !attrType.hasStaticShape())
81 return failure();
82 if (attrType != resultType)
83 return failure();
84
85 auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(),
86 resultType, elementsAttr);
87 rewriter.replaceOp(constOp, newConst.getResult());
88 return success();
89 }
90};
91
92struct TosaArithConstantToTosaConstPass
94 TosaArithConstantToTosaConstPass> {
95 using Base::Base;
96
97 void getDependentDialects(DialectRegistry &registry) const override {
98 registry.insert<arith::ArithDialect, tosa::TosaDialect>();
99 }
100
101 void runOnOperation() override {
102 auto *ctx = &getContext();
103 RewritePatternSet patterns(ctx);
104 patterns.add<ArithConstantToTosaConst>(ctx);
105
106 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
107 signalPassFailure();
108 }
109};
110
111} // namespace
return success()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...