MLIR  20.0.0git
ConversionUtils.cpp
Go to the documentation of this file.
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
// Returns `nParallelLoops` iterator types, every one of them
// utils::IteratorType::parallel.
21  return SmallVector<utils::IteratorType>(nParallelLoops,
22  utils::IteratorType::parallel);
23 }
24 
// Returns a copy of `values` with every null Value dropped; the remaining
// values keep their original relative order.
27  SmallVector<Value> condensedValues;
28  for (auto value : values)
// A default-constructed (null) Value converts to false and is skipped.
29  if (value)
30  condensedValues.push_back(value);
31  return condensedValues;
32 }
33 
35  Value max, OpBuilder &rewriter) {
// Clamps the float `arg` to [min, max]. It first takes the smaller of `arg`
// and `max`, then the larger of that result and `min`, so the upper bound is
// applied first.
// NOTE(review): arith.minimumf / arith.maximumf are believed to propagate NaN
// operands; confirm if NaN handling matters to callers.
36  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
41  OpBuilder &rewriter, bool isUnsigned) {
// Clamps the integer `arg` to [min, max]. It first takes the larger of `min`
// and `arg` (applying the lower bound), then the smaller of that result and
// `max`. `isUnsigned` selects unsigned (MaxUI/MinUI) or signed (MaxSI/MinSI)
// comparisons.
42  if (isUnsigned) {
43  auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
44  return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
45  }
46  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
47  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
48 }
49 
50 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
51  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
52  if (ty.getSignedness() == IntegerType::Unsigned) {
53  uint64_t uvalue = value;
54  APInt intMin = APInt::getMinValue(bitwidth);
55  APInt intMax = APInt::getMaxValue(bitwidth);
56  return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
57  }
58 
59  APInt intMin = APInt::getSignedMinValue(bitwidth);
60  APInt intMax = APInt::getSignedMaxValue(bitwidth);
61  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
62 }
63 
64 namespace {
65 // Given two tensors of high and low ranks, derive the output shape
66 // to reshape the lower rank to.
67 // Examples:
68 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
69 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
70 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
71 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
72 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
73 LogicalResult
74 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
75  ArrayRef<int64_t> lowerRankShape,
76  SmallVectorImpl<int64_t> &reshapeOutputShape) {
77  // Initialize new shapes with [1] * higherRank.
78  int64_t higherRank = higherRankShape.size();
79  int64_t lowerRank = lowerRankShape.size();
80 
81  reshapeOutputShape.assign(higherRank, 1);
82 
83  int64_t higherRankDim;
84  int64_t lowerRankDim;
85 
86  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
87  i--, j--) {
88  higherRankDim = higherRankShape[i];
89  lowerRankDim = lowerRankShape[j];
90 
91  if (lowerRankDim == 1 && higherRankDim > 1)
92  reshapeOutputShape[i] = 1;
93  else if ((lowerRankDim > 1 && higherRankDim == 1) ||
94  (lowerRankDim == higherRankDim))
95  reshapeOutputShape[i] = lowerRankDim;
96  else if (higherRankDim != lowerRankDim)
97  return failure();
98  }
99  return success();
100 }
101 } // namespace
102 
104  Value &input1, Value &input2) {
// Makes the ranks of `input1` and `input2` equal. The lower-rank value is
// reshaped to the shape computed by computeReshapeOutput (right-aligned,
// padded with leading 1s) via a tosa::ReshapeOp created at `loc`.
//
// On success, `input1`/`input2` are overwritten in place. Their operand order
// is preserved, and one of them now holds the reshape result.
//
// Returns failure(), leaving both inputs untouched, in two cases:
//  - either input is not a RankedTensorType;
//  - computeReshapeOutput rejects the shapes.
// If the ranks already match, it succeeds without creating any op.
105  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
106  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
107 
108  if (!input1Ty || !input2Ty) {
109  return failure();
110  }
111 
112  int64_t input1Rank = input1Ty.getRank();
113  int64_t input2Rank = input2Ty.getRank();
114 
115  if (input1Rank == input2Rank)
116  return success();
117 
// Identify which operand needs the reshape (the one with the smaller rank).
118  Value higherTensorValue, lowerTensorValue;
119  if (input1Rank > input2Rank) {
120  higherTensorValue = input1;
121  lowerTensorValue = input2;
122  } else {
123  higherTensorValue = input2;
124  lowerTensorValue = input1;
125  }
126 
127  ArrayRef<int64_t> higherRankShape =
128  llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
129  ArrayRef<int64_t> lowerRankShape =
130  llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
131 
132  SmallVector<int64_t, 4> reshapeOutputShape;
133 
134  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
135  .failed())
136  return failure();
137 
// The reshape keeps the element type; only the shape changes.
138  auto reshapeInputType =
139  llvm::cast<RankedTensorType>(lowerTensorValue.getType());
140  auto reshapeOutputType = RankedTensorType::get(
141  ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
142 
// The same shape is used for the result type and the new-shape attribute.
143  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
144  loc, reshapeOutputType, lowerTensorValue,
145  rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
146 
// Write results back in the caller's original operand order.
147  if (input1Rank > input2Rank) {
148  input1 = higherTensorValue;
149  input2 = reshapeLower.getResult();
150  } else {
151  input1 = reshapeLower.getResult();
152  input2 = higherTensorValue;
153  }
154 
155  return success();
156 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:187
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
bool validIntegerRange(IntegerType ty, int64_t value)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.