31#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
32#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
42 ConversionPatternRewriter &rewriter,
54 const Attribute attribute = namedAttribute.getValue();
57 if (
const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
58 const std::optional<Attribute> convertedAttribute =
59 typeConverter->convertTypeAttribute(intAttr.getType(), attribute);
60 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
64 if (
const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
65 Type type = typeAttr.getValue();
66 const std::optional<Attribute> convertedAttribute =
67 typeConverter->convertTypeAttribute(type, attribute);
68 if (!convertedAttribute)
69 return rewriter.notifyMatchFailure(op,
70 "Failed to convert type attribute.");
71 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
75 if (
const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
76 const Type type = denseElementsAttr.getType();
77 const std::optional<Attribute> convertedAttribute =
78 typeConverter->convertTypeAttribute(type, denseElementsAttr);
79 if (!convertedAttribute)
80 return rewriter.notifyMatchFailure(
81 op,
"Failed to convert dense elements attribute.");
82 state.
addAttribute(namedAttribute.getName(), convertedAttribute.value());
91 rewriter.inlineRegionBefore(region, *newRegion, newRegion->
begin());
92 if (
failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
96 Operation *newOp = rewriter.create(state);
107 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context)
108 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
111 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
112 ConversionPatternRewriter &rewriter)
const final {
113 if (!isa<tosa::TosaOp>(op))
114 return rewriter.notifyMatchFailure(
116 "Support for operations other than TOSA has not been implemented.");
118 return convertGenericOp(op, operands, rewriter, typeConverter);
126class ConvertArgMaxOpWithBoundsChecking
127 :
public OpConversionPattern<tosa::ArgMaxOp> {
128 using OpConversionPattern::OpConversionPattern;
131 matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const final {
134 const int32_t axis = op.getAxis();
135 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
136 if (!inputType || !inputType.isStaticDim(axis))
137 return rewriter.notifyMatchFailure(
138 op,
"Requires a static axis dimension for bounds checking.");
139 const int64_t axisDim = inputType.getDimSize(axis);
140 if (axisDim >= std::numeric_limits<int32_t>::max())
141 return rewriter.notifyMatchFailure(
142 op,
"Axis dimension is too large to narrow safely.");
144 const Type resultType = op.getOutput().getType();
145 const Type newResultType = typeConverter->convertType(resultType);
146 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
147 adaptor.getInput(), axis);
152class ConvertCastOpWithBoundsChecking
153 :
public OpConversionPattern<tosa::CastOp> {
154 using OpConversionPattern::OpConversionPattern;
157 matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter)
const final {
159 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
161 if (!inputType || !resultType)
164 const auto elementInputIntType =
165 dyn_cast<IntegerType>(inputType.getElementType());
166 const auto elementResultIntType =
167 dyn_cast<IntegerType>(resultType.getElementType());
168 if (elementInputIntType && elementResultIntType &&
169 elementInputIntType.getWidth() > elementResultIntType.getWidth())
170 return rewriter.notifyMatchFailure(
171 op,
"Narrowing cast may lead to data loss.");
173 rewriter.replaceOpWithNewOp<tosa::CastOp>(
174 op, typeConverter->convertType(resultType), adaptor.getInput());
179class ConvertClampOpWithBoundsChecking
180 :
public OpConversionPattern<tosa::ClampOp> {
181 using OpConversionPattern::OpConversionPattern;
184 matchAndRewrite(tosa::ClampOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const final {
186 const auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
187 const auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
188 if (!minAttr || !maxAttr)
191 const int64_t
min = minAttr.getInt();
192 const int64_t
max = maxAttr.getInt();
194 if (
min < std::numeric_limits<int32_t>::min() ||
195 max > std::numeric_limits<int32_t>::max())
196 return rewriter.notifyMatchFailure(
197 op,
"Clamp bounds exceed int32 range. Narrowing cast may lead to "
200 const Type resultType = op.getOutput().getType();
201 const Type newResultType = typeConverter->convertType(resultType);
203 const IntegerType int32Type = IntegerType::get(rewriter.getContext(), 32);
204 const IntegerAttr newMinAttr =
205 rewriter.getIntegerAttr(int32Type,
static_cast<int32_t
>(
min));
206 const IntegerAttr newMaxAttr =
207 rewriter.getIntegerAttr(int32Type,
static_cast<int32_t
>(
max));
208 rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
209 adaptor.getInput(), newMinAttr,
210 newMaxAttr, op.getNanModeAttr());
215template <
typename OpTy>
216class ConvertTypedOp :
public OpConversionPattern<OpTy> {
217 using OpConversionPattern<OpTy>::OpConversionPattern;
220 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
221 ConversionPatternRewriter &rewriter)
const final {
222 return convertGenericOp(op, adaptor.getOperands(), rewriter,
223 this->getTypeConverter());
227struct TosaNarrowI64ToI32
230 explicit TosaNarrowI64ToI32() =
default;
231 explicit TosaNarrowI64ToI32(
const TosaNarrowI64ToI32PassOptions &
options)
232 : TosaNarrowI64ToI32() {
233 this->aggressiveRewrite =
options.aggressiveRewrite;
234 this->convertFunctionBoundaries =
options.convertFunctionBoundaries;
237 void runOnOperation()
override {
240 TypeConverter typeConverter;
241 typeConverter.addConversion([](Type type) -> Type {
return type; });
242 typeConverter.addConversion([](IntegerType type) -> Type {
243 if (!type.isInteger(64))
245 return IntegerType::get(type.getContext(), 32);
247 typeConverter.addConversion(
248 [&typeConverter](RankedTensorType type) -> Type {
249 const Type elementType = type.getElementType();
252 return RankedTensorType::get(type.getShape(),
253 typeConverter.convertType(elementType));
256 const auto materializeCast = [](OpBuilder &builder, Type resultType,
258 if (inputs.size() != 1)
260 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
262 typeConverter.addSourceMaterialization(materializeCast);
263 typeConverter.addTargetMaterialization(materializeCast);
265 typeConverter.addTypeAttributeConversion(
266 [](IntegerType type, IntegerAttr attribute) -> Attribute {
267 const APInt value = attribute.getValue().truncSSat(32);
268 return IntegerAttr::get(IntegerType::get(type.getContext(), 32),
271 typeConverter.addTypeAttributeConversion(
272 [&typeConverter](ShapedType type,
273 DenseIntElementsAttr attr) -> Attribute {
274 const ShapedType newType =
275 cast<ShapedType>(typeConverter.convertType(type));
276 const auto oldElementType = cast<IntegerType>(type.getElementType());
277 const auto newElementType =
278 cast<IntegerType>(newType.getElementType());
279 if (oldElementType.getWidth() == newElementType.getWidth())
282 DenseElementsAttr mapped =
283 attr.
mapValues(newElementType, [&](
const APInt &v) {
284 return v.truncSSat(newElementType.getWidth());
289 ConversionTarget
target(*context);
290 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
291 [&typeConverter](Operation *op) {
295 if (convertFunctionBoundaries) {
296 target.addDynamicallyLegalOp<func::FuncOp>(
297 [&typeConverter](func::FuncOp op) {
298 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
299 typeConverter.isLegal(&op.getBody());
301 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
302 const FunctionType funcType =
307 target.addDynamicallyLegalOp<func::FuncOp>(
308 [](func::FuncOp op) {
return true; });
309 target.addDynamicallyLegalOp<func::ReturnOp>(
310 [](func::ReturnOp op) {
return true; });
313 RewritePatternSet
patterns(context);
314 if (convertFunctionBoundaries) {
315 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
319 if (aggressiveRewrite) {
320 patterns.add<ConvertGenericOp>(typeConverter, context);
323 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
325 patterns.add<ConvertClampOpWithBoundsChecking>(typeConverter, context);
327 patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context);
328 patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context);
329 patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context);
330 patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context);
331 patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context);
332 patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context);
333 patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context);
334 patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context);
336 patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context);
338 patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context);
339 patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context);
343 applyFullConversion(getOperation(),
target, std::move(
patterns))))
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
NamedAttribute represents a combination of a name and an Attribute value.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
Type getType() const
Return the type of this value.
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.