MLIR  15.0.0git
CoversionUtils.h
Go to the documentation of this file.
1 //===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
14 #define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
15 
19 #include "mlir/IR/PatternMatch.h"
20 
21 namespace mlir {
22 namespace tosa {
23 
24 // Creates a SmallVector of Stringrefs for N parallel loops
25 SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
26 
27 // Takes a vector of values and condenses them to a vector with no gaps.
28 SmallVector<Value> condenseValues(const SmallVector<Value> &values);
29 
30 // Takes the parameters for a clamp and turns it into a series of ops.
31 template <typename T, typename P>
32 arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
33  arith::ConstantOp max, P pred,
34  OpBuilder &rewriter) {
35  auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
36  auto minOrArg =
37  rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
38  auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
39  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
40 }
41 
42 // Returns the values in an attribute as an array of values.
43 template <typename T>
44 void getValuesFromIntArrayAttribute(ArrayAttr attr,
45  SmallVector<T> &arrayValues) {
46  for (Attribute val : attr.getValue()) {
47  arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
48  }
49 }
50 
51 // Checks for a dynamic batch dim in any of the passed parameters of an op.
52 // The batch dimention must be #0 and the rest of the dimensions must be static.
53 template <typename Op>
55  Op op,
56  ArrayRef<Value> params) {
57  SmallVector<ShapedType> dynTypes;
58  SmallVector<Value> dynamicDims;
59  for (const Value &param : params) {
60  auto paramTy = param.getType().cast<ShapedType>();
61  if (!paramTy.hasStaticShape())
62  dynTypes.push_back(paramTy);
63  }
64 
65  if (dynTypes.empty())
66  return dynamicDims;
67 
68  for (const ShapedType &dynTy : dynTypes) {
69  if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
70  (void)rewriter.notifyMatchFailure(
71  op, "input can only be dynamic for batch size");
72  return llvm::None;
73  }
74  }
75 
76  dynamicDims.push_back(
77  rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
78  return dynamicDims;
79 }
80 
81 } // namespace tosa
82 } // namespace mlir
83 
84 #endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
Include the generated interface declarations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector< T > &arrayValues)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
SmallVector< StringRef > getNParallelLoopsAttrs(unsigned nParallelLoops)
Optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:108
arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min, arith::ConstantOp max, P pred, OpBuilder &rewriter)
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
This class helps build Operations.
Definition: Builders.h:177
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)