MLIR  19.0.0git
ConversionUtils.cpp
Go to the documentation of this file.
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21  return SmallVector<utils::IteratorType>(nParallelLoops,
22  utils::IteratorType::parallel);
23 }
24 
27  SmallVector<Value> condensedValues;
28  for (auto value : values)
29  if (value)
30  condensedValues.push_back(value);
31  return condensedValues;
32 }
33 
35  Value max, OpBuilder &rewriter) {
36  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
41  OpBuilder &rewriter) {
42  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
43  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
44 }
45 
46 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
47  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
48  if (ty.getSignedness() == IntegerType::Unsigned) {
49  uint64_t uvalue = value;
50  APInt intMin = APInt::getMinValue(bitwidth);
51  APInt intMax = APInt::getMaxValue(bitwidth);
52  return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
53  }
54 
55  APInt intMin = APInt::getSignedMinValue(bitwidth);
56  APInt intMax = APInt::getSignedMaxValue(bitwidth);
57  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
58 }
59 
60 namespace {
61 // Given two tensors of high and low ranks, derive the output shape
62 // to reshape the lower rank to.
63 // Examples:
64 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
65 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
66 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
67 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
68 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
70 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
71  ArrayRef<int64_t> lowerRankShape,
72  SmallVectorImpl<int64_t> &reshapeOutputShape) {
73  // Initialize new shapes with [1] * higherRank.
74  int64_t higherRank = higherRankShape.size();
75  int64_t lowerRank = lowerRankShape.size();
76 
77  reshapeOutputShape.assign(higherRank, 1);
78 
79  int64_t higherRankDim;
80  int64_t lowerRankDim;
81 
82  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
83  i--, j--) {
84  higherRankDim = higherRankShape[i];
85  lowerRankDim = lowerRankShape[j];
86 
87  if (lowerRankDim == 1 && higherRankDim > 1)
88  reshapeOutputShape[i] = 1;
89  else if ((lowerRankDim > 1 && higherRankDim == 1) ||
90  (lowerRankDim == higherRankDim))
91  reshapeOutputShape[i] = lowerRankDim;
92  else if (higherRankDim != lowerRankDim)
93  return failure();
94  }
95  return success();
96 }
97 } // namespace
98 
100  Value &input1, Value &input2) {
101  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
102  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
103 
104  if (!input1Ty || !input2Ty) {
105  return failure();
106  }
107 
108  int64_t input1Rank = input1Ty.getRank();
109  int64_t input2Rank = input2Ty.getRank();
110 
111  if (input1Rank == input2Rank)
112  return success();
113 
114  Value higherTensorValue, lowerTensorValue;
115  if (input1Rank > input2Rank) {
116  higherTensorValue = input1;
117  lowerTensorValue = input2;
118  } else {
119  higherTensorValue = input2;
120  lowerTensorValue = input1;
121  }
122 
123  ArrayRef<int64_t> higherRankShape =
124  llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
125  ArrayRef<int64_t> lowerRankShape =
126  llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
127 
128  SmallVector<int64_t, 4> reshapeOutputShape;
129 
130  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
131  .failed())
132  return failure();
133 
134  auto reshapeInputType =
135  llvm::cast<RankedTensorType>(lowerTensorValue.getType());
136  auto reshapeOutputType = RankedTensorType::get(
137  ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
138 
139  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
140  loc, reshapeOutputType, lowerTensorValue,
141  rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
142 
143  if (input1Rank > input2Rank) {
144  input1 = higherTensorValue;
145  input2 = reshapeLower.getResult();
146  } else {
147  input1 = reshapeLower.getResult();
148  input2 = higherTensorValue;
149  }
150 
151  return success();
152 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
bool validIntegerRange(IntegerType ty, int64_t value)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.