MLIR 22.0.0git
TosaConvertIntegerTypeToSignless.cpp
Go to the documentation of this file.
//===- TosaConvertIntegerTypeToSignless.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// -----------
// Motivation:
// -----------

// The TOSA specification uses a signless type system, which means that
// information about signedness must be encapsulated by the operations
// themselves. For example, tosa.rescale provides the attributes
// `input_unsigned` and `output_unsigned` to indicate whether the input/output
// should be interpreted as unsigned or signed.

// The TOSA dialect, on the other hand, allows the use of signed or unsigned
// types in addition to signless. As such, when converting from the TOSA dialect
// to other formats, we need to ensure that we conform to the TOSA
// specification.

// ---------
// Overview:
// ---------

// This pass converts tensors with signed or unsigned integer element types to
// use signless integer element types. It currently does this greedily for all
// operations and can also change the signature of the function. Should the
// signature of the entrypoint function change, it is the responsibility of the
// user to track the signedness of the inputs and outputs separately.
38
39namespace mlir {
40namespace tosa {
41
42#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
43#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
44
45namespace {
46class ToSignlessTensorTypeConverter : public TypeConverter {
47 static Type convertType(Type type) {
48 const auto tensorType = dyn_cast<TensorType>(type);
49 if (!tensorType)
50 return type;
51
52 const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
53 if (!intType ||
54 intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
55 return type;
56
57 const auto signlessType = IntegerType::get(
58 intType.getContext(), intType.getWidth(), IntegerType::Signless);
59 return tensorType.cloneWith(std::nullopt, signlessType);
60 }
61
62public:
63 explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
64};
65
66class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
67public:
68 ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
69 MLIRContext *context)
70 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
71
72 LogicalResult
73 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
74 ConversionPatternRewriter &rewriter) const final {
75 // Typically TOSA operators have a single result, but some have an
76 // arbitrary number. 4 seems like a good balance as an optimization
77 // hint for storing result types.
78 constexpr unsigned int numResults = 4;
79
80 // Convert integer types to signless
82 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
83 return failure();
84
85 // Create new op with replaced operands and results
86 auto *newOp = Operation::create(
87 op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
88 op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
89
90 // Handle regions in e.g. tosa.cond_if and tosa.while_loop
91 for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
92 Region &before = std::get<0>(regions);
93 Region &parent = std::get<1>(regions);
94 rewriter.inlineRegionBefore(before, parent, parent.end());
95 if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
96 return failure();
97 }
98
99 // Replace with rewritten op
100 rewriter.insert(newOp);
101 rewriter.replaceOp(op, newOp->getResults());
102 return success();
103 }
104};
105
106class ConvertTosaConstWithIntegerTensorType
107 : public OpConversionPattern<tosa::ConstOp> {
108 using OpConversionPattern::OpConversionPattern;
109
110 LogicalResult
111 matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
112 ConversionPatternRewriter &rewriter) const final {
113 const ElementsAttr oldAttr = op.getValues();
114 const auto oldTy = llvm::cast<ShapedType>(oldAttr.getType());
115 const auto newTy =
116 llvm::cast<ShapedType>(typeConverter->convertType(oldTy));
117 if (oldTy == newTy)
118 return success();
119
120 ElementsAttr newAttr = oldAttr;
121 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(oldAttr)) {
122 newAttr = DenseElementsAttr::get(newTy, denseAttr.getRawData());
123 } else {
124 return rewriter.notifyMatchFailure(op, "unknown elements attribute type");
125 }
126
127 rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newTy, newAttr);
128 return success();
129 }
130};
131
132class TosaConvertIntegerTypeToSignless
133 : public impl::TosaConvertIntegerTypeToSignlessBase<
134 TosaConvertIntegerTypeToSignless> {
135public:
136 void runOnOperation() override {
137 MLIRContext *context = &getContext();
138 ConversionTarget target(*context);
139 ToSignlessTensorTypeConverter typeConverter;
140
141 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
142 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
143 typeConverter.isLegal(&op.getBody());
144 });
145 target.addDynamicallyLegalOp<tosa::ConstOp>([&](tosa::ConstOp op) {
146 return typeConverter.isLegal(op.getType()) &&
147 typeConverter.isLegal(op.getValues().getType());
148 });
149 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
150 return typeConverter.isLegal(op->getOperandTypes()) &&
151 typeConverter.isLegal(op->getResultTypes());
152 });
155 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
156 patterns, typeConverter);
157 patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
158 patterns.add<ConvertTosaConstWithIntegerTensorType>(typeConverter, context);
159
160 if (failed(
161 applyFullConversion(getOperation(), target, std::move(patterns))))
164};
166} // namespace
167
168} // namespace tosa
169} // namespace mlir
return success()
b getContext())
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:218
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns