MLIR  20.0.0git
ConversionUtils.cpp
Go to the documentation of this file.
1 //===- ConversionUtils.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions for TOSA lowering
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
20 mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
21  return SmallVector<utils::IteratorType>(nParallelLoops,
22  utils::IteratorType::parallel);
23 }
24 
27  SmallVector<Value> condensedValues;
28  for (auto value : values)
29  if (value)
30  condensedValues.push_back(value);
31  return condensedValues;
32 }
33 
35  Value max, OpBuilder &rewriter) {
36  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
37  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
38 }
39 
41  OpBuilder &rewriter, bool isUnsigned) {
42  if (isUnsigned) {
43  auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
44  return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
45  }
46  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
47  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
48 }
49 
50 bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
51  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
52  if (ty.getSignedness() == IntegerType::Unsigned) {
53  uint64_t uvalue = value;
54  APInt intMin = APInt::getMinValue(bitwidth);
55  APInt intMax = APInt::getMaxValue(bitwidth);
56  return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
57  }
58 
59  APInt intMin = APInt::getSignedMinValue(bitwidth);
60  APInt intMax = APInt::getSignedMaxValue(bitwidth);
61  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
62 }
63 
64 namespace {
65 // Given two tensors of high and low ranks, derive the output shape
66 // to reshape the lower rank to.
67 // Examples:
68 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
69 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
70 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
71 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
72 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
73 LogicalResult
74 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
75  ArrayRef<int64_t> lowerRankShape,
76  SmallVectorImpl<int64_t> &reshapeOutputShape) {
77  // Initialize new shapes with [1] * higherRank.
78  int64_t higherRank = higherRankShape.size();
79  int64_t lowerRank = lowerRankShape.size();
80 
81  reshapeOutputShape.assign(higherRank, 1);
82 
83  int64_t higherRankDim;
84  int64_t lowerRankDim;
85 
86  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
87  i--, j--) {
88  higherRankDim = higherRankShape[i];
89  lowerRankDim = lowerRankShape[j];
90 
91  if (lowerRankDim == 1 && higherRankDim > 1)
92  reshapeOutputShape[i] = 1;
93  else if ((lowerRankDim > 1 && higherRankDim == 1) ||
94  (lowerRankDim == higherRankDim))
95  reshapeOutputShape[i] = lowerRankDim;
96  else if (higherRankDim != lowerRankDim)
97  return failure();
98  }
99  return success();
100 }
101 } // namespace
102 
104  Value &input1, Value &input2) {
105  ImplicitLocOpBuilder builder(loc, rewriter);
106  return EqualizeRanks(builder, input1, input2);
107 }
108 
110  Value &input1, Value &input2) {
111  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
112  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
113 
114  if (!input1Ty || !input2Ty) {
115  return failure();
116  }
117 
118  int64_t input1Rank = input1Ty.getRank();
119  int64_t input2Rank = input2Ty.getRank();
120 
121  if (input1Rank == input2Rank)
122  return success();
123 
124  Value higherTensorValue, lowerTensorValue;
125  if (input1Rank > input2Rank) {
126  higherTensorValue = input1;
127  lowerTensorValue = input2;
128  } else {
129  higherTensorValue = input2;
130  lowerTensorValue = input1;
131  }
132 
133  ArrayRef<int64_t> higherRankShape =
134  llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
135  ArrayRef<int64_t> lowerRankShape =
136  llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
137 
138  SmallVector<int64_t, 4> reshapeOutputShape;
139 
140  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
141  .failed())
142  return failure();
143 
144  auto reshapeInputType =
145  llvm::cast<RankedTensorType>(lowerTensorValue.getType());
146  auto reshapeOutputType = RankedTensorType::get(
147  ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
148 
149  auto reshapeLower = builder.create<tosa::ReshapeOp>(
150  reshapeOutputType, lowerTensorValue,
151  builder.getDenseI64ArrayAttr(reshapeOutputShape));
152 
153  if (input1Rank > input2Rank) {
154  input1 = higherTensorValue;
155  input2 = reshapeLower.getResult();
156  } else {
157  input1 = reshapeLower.getResult();
158  input2 = higherTensorValue;
159  }
160 
161  return success();
162 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
bool validIntegerRange(IntegerType ty, int64_t value)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.