MLIR  18.0.0git
ConversionUtils.cpp
Go to the documentation of this file.
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
// getNParallelLoopsAttrs: returns a SmallVector holding `nParallelLoops`
// copies of utils::IteratorType::parallel (empty when nParallelLoops == 0).
// NOTE(review): the return-type line of this definition is missing from this
// listing; the tooltip at the end of the page gives the full signature as
// SmallVector<utils::IteratorType> getNParallelLoopsAttrs(unsigned nParallelLoops).
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21  return SmallVector<utils::IteratorType>(nParallelLoops,
22  utils::IteratorType::parallel);
23 }
24 
// condenseValues: returns a new vector containing the entries of `values`
// that convert to true (i.e. non-null Values), in their original order.
// The input vector is not modified.
// NOTE(review): the signature lines of this definition are missing from this
// listing; the tooltip at the end of the page gives it as
// SmallVector<Value> condenseValues(const SmallVector<Value> &values).
27  SmallVector<Value> condensedValues;
28  for (auto value : values)
// Skip null entries; everything else is copied through unchanged.
29  if (value)
30  condensedValues.push_back(value);
31  return condensedValues;
32 }
33 
// clampFloatHelper: clamps the float `arg` to [min, max] by emitting
//   minValue = arith.minimumf(arg, max)
//   result   = arith.maximumf(minValue, min)
// at `loc` and returning `result`. Because the lower bound is applied last,
// the result is `min` whenever min > max.
// NOTE(review): arith.minimumf/maximumf are, per the arith dialect docs, the
// NaN-propagating variants (a NaN operand yields NaN) - confirm this is the
// intended NaN behaviour for the TOSA lowering that uses this helper.
// NOTE(review): the first signature line is missing from this listing; the
// tooltip at the end of the page gives the full signature as
// Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
//                        OpBuilder &rewriter).
35  Value max, OpBuilder &rewriter) {
36  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
// clampIntHelper: clamps the integer `arg` to [min, max] using signed
// comparisons (arith.cmpi slt) and two arith.select ops emitted at `loc`:
//   minOrArg = (arg < min) ? min : arg
//   result   = (max < arg) ? max : minOrArg
// All comparisons are signed, so operands are interpreted as signed integers.
// Note that the upper-bound test compares the original `arg` (not minOrArg)
// against `max`.
// NOTE(review): the first signature line is missing from this listing; the
// tooltip at the end of the page gives the full signature as
// Value clampIntHelper(Location loc, Value arg, Value min, Value max,
//                      OpBuilder &rewriter).
41  OpBuilder &rewriter) {
42  auto smallerThanMin =
43  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
44  auto minOrArg =
45  rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
// "largerThanMax" is computed as max < arg (signed), i.e. arg > max.
46  auto largerThanMax =
47  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
48  return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
49 }
50 
51 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
52  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
53  if (ty.getSignedness() == IntegerType::Unsigned) {
54  uint64_t uvalue = value;
55  APInt intMin = APInt::getMinValue(bitwidth);
56  APInt intMax = APInt::getMaxValue(bitwidth);
57  return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
58  }
59 
60  APInt intMin = APInt::getSignedMinValue(bitwidth);
61  APInt intMax = APInt::getSignedMaxValue(bitwidth);
62  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
63 }
64 
65 namespace {
66 // Given two tensors of high and low ranks, derive the output shape
67 // to reshape the lower rank to.
68 // Examples:
69 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
70 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
71 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
72 // If lower=[a], higher=[a, b, a], [a] reshaped into [1, 1, a].
73 // If lower=[], higher=[a, b, c], [] reshaped into [1, 1, 1].
// If lower=[b], higher=[a, c] with b != c (both > 1), the call fails.
//
// Dimensions are paired right-aligned (last with last). For each pair:
//   lower == 1 && higher > 1                  -> output dim 1
//   (lower > 1 && higher == 1) || lower == higher -> output dim = lower
//   anything else                             -> return failure()
// Leading output dims with no lower-rank counterpart stay 1. On failure,
// `reshapeOutputShape` may be partially filled and should not be used.
// NOTE(review): a lower dim of 1 against a higher dim <= 0, and any pairing of
// a dynamic dim with a static one (if dynamic dims are encoded as a negative
// sentinel, as ShapedType::kDynamic is), takes the failure branch - confirm
// this is intended.
// NOTE(review): the return-type line of this definition is missing from this
// listing; the body returns success()/failure() and the caller uses
// `.failed()` on the result.
75 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
76  ArrayRef<int64_t> lowerRankShape,
77  SmallVectorImpl<int64_t> &reshapeOutputShape) {
78  // Initialize new shapes with [1] * higherRank.
79  int64_t higherRank = higherRankShape.size();
80  int64_t lowerRank = lowerRankShape.size();
81 
82  reshapeOutputShape.assign(higherRank, 1);
83 
84  int64_t higherRankDim;
85  int64_t lowerRankDim;
86 
// Walk both shapes from the innermost dim outwards; stops when either shape
// runs out of dims.
87  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
88  i--, j--) {
89  higherRankDim = higherRankShape[i];
90  lowerRankDim = lowerRankShape[j];
91 
// Writing 1 here is a no-op (the slot was already initialised to 1) but
// makes the broadcast case explicit.
92  if (lowerRankDim == 1 && higherRankDim > 1)
93  reshapeOutputShape[i] = 1;
94  else if ((lowerRankDim > 1 && higherRankDim == 1) ||
95  (lowerRankDim == higherRankDim))
96  reshapeOutputShape[i] = lowerRankDim;
// The equal case was handled above, so this condition is always true when
// reached: every remaining combination is a failure.
97  else if (higherRankDim != lowerRankDim)
98  return failure();
99  }
100  return success();
101 }
102 } // namespace
103 
// EqualizeRanks: if `input1` and `input2` are ranked tensors of different
// ranks, creates a tosa.reshape (via `rewriter`, at `loc`) that raises the
// lower-rank operand to the higher rank. The shape is computed by
// computeReshapeOutput, so the new leading dims are 1. The reshaped value is
// written back into the corresponding in/out parameter.
//   - Returns failure() if either operand is not a RankedTensorType, or if
//     computeReshapeOutput rejects the shapes. No op is created in either case.
//   - Returns success() without creating any op when the ranks already match.
//   - Otherwise returns success(). The higher-rank operand is left untouched,
//     and the operands keep their original positions (input1 stays input1).
// NOTE(review): the signature line naming this function is missing from this
// listing; the tooltip at the end of the page gives it as
// LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
//                             Value &input1, Value &input2).
105  Value &input1, Value &input2) {
106  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
107  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
108 
109  if (!input1Ty || !input2Ty) {
110  return failure();
111  }
112 
113  int64_t input1Rank = input1Ty.getRank();
114  int64_t input2Rank = input2Ty.getRank();
115 
116  if (input1Rank == input2Rank)
117  return success();
118 
// Pick which operand gets reshaped (the lower-rank one).
119  Value higherTensorValue, lowerTensorValue;
120  if (input1Rank > input2Rank) {
121  higherTensorValue = input1;
122  lowerTensorValue = input2;
123  } else {
124  higherTensorValue = input2;
125  lowerTensorValue = input1;
126  }
127 
// These casts cannot fail: both operands already passed the dyn_casts above.
128  ArrayRef<int64_t> higherRankShape =
129  llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
130  ArrayRef<int64_t> lowerRankShape =
131  llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
132 
133  SmallVector<int64_t, 4> reshapeOutputShape;
134 
135  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
136  .failed())
137  return failure();
138 
// The reshaped tensor keeps the lower-rank operand's element type.
139  auto reshapeInputType =
140  llvm::cast<RankedTensorType>(lowerTensorValue.getType());
141  auto reshapeOutputType = RankedTensorType::get(
142  ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
143 
// The target shape is passed both as the result type and as the reshape's
// DenseI64ArrayAttr operand.
144  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
145  loc, reshapeOutputType, lowerTensorValue,
146  rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
147 
// Write results back in the caller's original operand order.
148  if (input1Rank > input2Rank) {
149  input1 = higherTensorValue;
150  input2 = reshapeLower.getResult();
151  } else {
152  input1 = reshapeLower.getResult();
153  input2 = higherTensorValue;
154  }
155 
156  return success();
157 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
bool validIntegerRange(IntegerType ty, int64_t value)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.