MLIR  20.0.0git
TosaToTensor.cpp
Go to the documentation of this file.
1 //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Tensor dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/PatternMatch.h"
21 
22 #include <numeric>
23 
24 using namespace mlir;
25 using namespace tosa;
26 
27 namespace {
28 
29 // Infer the type to which the input of a 'tosa.reshape' op must be cast when
30 // lowered.
31 TensorType inferReshapeInputType(TypedValue<TensorType> input,
32  ArrayRef<int64_t> newShape) {
33  // No need to cast input for non-empty target shape
34  if (!newShape.empty())
35  return input.getType();
36 
37  // The input type must be cast into a tensor with the same rank and all static
38  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
39  // op that converts a dynamically shaped tensor into a 0D tensor. While such
40  // construct is not incorrect on its own, bufferization cannot properly handle
41  // it at the moment, so we avoid it.
42  SmallVector<int64_t> shape(input.getType().getRank(), 1);
43  return input.getType().clone(shape);
44 }
45 
46 // Infer the result type of 'tensor.expand_shape' in the collapse-expand
47 // pair emitted for a 'tosa.reshape' op.
48 TensorType inferReshapeExpandedType(TensorType inputType,
49  ArrayRef<int64_t> newShape) {
50  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
51  // with just '{}', as it will invoke the incorrect overload.
52  if (newShape.empty())
53  return inputType.clone(ArrayRef<int64_t>{});
54 
55  // Check if the input is static, and if so, get its total size
56  bool inputIsStatic = inputType.hasStaticShape();
57  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
58 
59  // Compute result shape
60  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
61  // If this is not a placeholder, do not change it
62  if (size >= 0)
63  return size;
64 
65  // If we do not know the total size of the tensor, keep this dimension
66  // dynamic in the result shape.
67  if (!inputIsStatic)
68  return ShapedType::kDynamic;
69 
70  // Calculate the product of all elements in 'newShape' except for the -1
71  // placeholder, which we discard by negating the result.
72  int64_t totalSizeNoPlaceholder = -std::accumulate(
73  newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
74 
75  // If there is a 0 component in 'newShape', resolve the placeholder as 0.
76  if (totalSizeNoPlaceholder == 0)
77  return 0;
78 
79  // Resolve the placeholder as the quotient between the total tensor size and
80  // the product of all other sizes.
81  return totalSize / totalSizeNoPlaceholder;
82  });
83 
84  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
85 
86  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
87  // shaped input from being reshaped into a statically shaped result. We may
88  // simply turn the first result dimension dynamic to address this.
89  if (!inputIsStatic && resultIsStatic)
90  resultShape[0] = ShapedType::kDynamic;
91 
92  // The 'tensor.expand_shape' op also forbids a statically shaped input from
93  // being reshaped into a dynamically shaped result, but the placeholder
94  // inference algorithm above guarantees that this will never be the case.
95  assert(!inputIsStatic || resultIsStatic);
96 
97  // Create result type
98  return inputType.clone(resultShape);
99 }
100 
101 // Infer the result type of 'tensor.collapse_shape' in the collapse-expand
102 // pair emitted for a 'tosa.reshape' op.
103 TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
104  auto lhsShape = lhsType.getShape();
105  auto rhsShape = rhsType.getShape();
106 
107  if (lhsShape.empty() || rhsShape.empty())
108  return lhsType.clone(ArrayRef<int64_t>{});
109 
110  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
111  return lhsType.clone({ShapedType::kDynamic});
112 
113  SmallVector<int64_t> intermediateShape;
114  unsigned currLhsDim = 0, currRhsDim = 0;
115  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
116  int64_t rhsSize = rhsShape[currRhsDim];
117  int64_t lhsSize = lhsShape[currLhsDim];
118  while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
119  currRhsDim < rhsShape.size()) {
120  if (lhsSize < rhsSize) {
121  currLhsDim++;
122  if (currLhsDim < lhsShape.size()) {
123  lhsSize *= lhsShape[currLhsDim];
124  }
125  } else {
126  currRhsDim++;
127  if (currRhsDim < rhsShape.size()) {
128  rhsSize *= rhsShape[currRhsDim];
129  }
130  }
131  }
132  if (lhsSize == rhsSize) {
133  intermediateShape.push_back(lhsSize);
134  }
135  currRhsDim++;
136  currLhsDim++;
137  }
138 
139  // Static shapes are guaranteed to be compatible by the op verifier, so all
140  // leftover dimensions should be 1.
141  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
142  assert(lhsShape[currLhsDim] == 1);
143  }
144  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
145  assert(rhsShape[currRhsDim] == 1);
146  }
147 
148  return lhsType.clone(intermediateShape);
149 }
150 
152 createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
153  auto srcShape = cast<TensorType>(srcType).getShape();
154  auto dstShape = cast<TensorType>(dstType).getShape();
155 
156  if (srcShape.empty() || dstShape.empty())
157  return {};
158 
159  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
160  assert(dstShape.size() == 1);
162  for (auto i : llvm::seq<int64_t>(srcShape.size()))
163  exprs.push_back(builder.getAffineDimExpr(i));
164  return {exprs};
165  }
166 
167  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
168  unsigned currSrcDim = 0, currDstDim = 0;
169  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
170  int64_t dstSize = dstShape[currDstDim];
171  int64_t srcSize = srcShape[currSrcDim];
172  while (srcSize < dstSize && currSrcDim < srcShape.size()) {
173  reassociationMap[currDstDim].push_back(
174  builder.getAffineDimExpr(currSrcDim++));
175  srcSize *= srcShape[currSrcDim];
176  }
177  if (srcSize == dstSize) {
178  reassociationMap[currDstDim].push_back(
179  builder.getAffineDimExpr(currSrcDim++));
180  // If the next dim in collapsedShape is not 1, treat subsequent dims in
181  // expandedShape which are 1 to be collapsed.
182  if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
183  while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
184  reassociationMap[currDstDim].push_back(
185  builder.getAffineDimExpr(currSrcDim++));
186  }
187  }
188  }
189  currDstDim++;
190  }
191 
192  // If the source and target shapes are compatible, both iterators must have
193  // reached the end. This condition is guaranteed by the op verifier for
194  // static shapes.
195  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
196  return reassociationMap;
197 }
198 
199 // Create a tensor.collapse_shape op that reshapes the input into the given
200 // result type.
201 Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
202  Value input) {
203  auto reassociationMap =
204  createReassociationMapForCollapse(builder, input.getType(), resultType);
205  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
206  reassociationMap);
207 }
208 
209 // Create a tensor.expand_shape op that reshapes the input into the given result
210 // type.
211 Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
212  Value input) {
213  auto reassociationMap =
214  createReassociationMapForCollapse(builder, resultType, input.getType());
215  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
216  reassociationMap);
217 }
218 
219 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
220 public:
222 
223  LogicalResult
224  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
225  ConversionPatternRewriter &rewriter) const final {
226  auto loc = reshape.getLoc();
227  auto resultType = cast_if_present<ShapedType>(
228  getTypeConverter()->convertType(reshape.getType()));
229  if (!resultType) {
230  return rewriter.notifyMatchFailure(reshape.getLoc(),
231  "could not convert result type");
232  }
233  auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
234  if (!input) {
235  return rewriter.notifyMatchFailure(reshape.getLoc(),
236  "expected input type to be tensor");
237  }
238  auto newShape = reshape.getNewShape();
239 
240  // Infer all intermediate types
241  auto inputType = inferReshapeInputType(input, newShape);
242  auto expandedType = inferReshapeExpandedType(inputType, newShape);
243  auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
244 
245  // Cast input if needed
246  auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
247 
248  // Emit collaspe-expand pair
249  auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
250  auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
251 
252  // Cast to final result type if needed
253  auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
254  rewriter.replaceOp(reshape, result);
255  return success();
256  }
257 };
258 
259 class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
260 public:
262 
263  LogicalResult
264  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const final {
266  Location loc = sliceOp.getLoc();
267  Value input = adaptor.getInput1();
268  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
269  if (llvm::isa<UnrankedTensorType>(resultType))
270  return failure();
271  SmallVector<int64_t> strides, sizes;
272  ArrayRef<int64_t> starts = sliceOp.getStart();
273  strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
274 
275  SmallVector<Value> dynSizes;
276  for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
277  int64_t size = i.value();
278  size_t index = i.index();
279  sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
280  if (!ShapedType::isDynamic(sizes.back()))
281  continue;
282 
283  auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
284  auto offset = rewriter.create<arith::ConstantOp>(
285  loc, rewriter.getIndexAttr(starts[index]));
286  dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
287  }
288 
289  auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
290  sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
291  ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
292  rewriter.getDenseI64ArrayAttr(sizes),
293  rewriter.getDenseI64ArrayAttr(strides));
294 
295  rewriter.replaceOp(sliceOp, newSliceOp.getResult());
296  return success();
297  }
298 };
299 
300 class PadConverter : public OpConversionPattern<tosa::PadOp> {
301 public:
303 
304  LogicalResult
305  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
306  ConversionPatternRewriter &rewriter) const final {
307  auto loc = padOp.getLoc();
308  auto input = padOp.getInput1();
309  auto padding = padOp.getPadding();
310 
311  ShapedType inputTy = cast<ShapedType>(input.getType());
312  Type elementTy = inputTy.getElementType();
313  int64_t rank = inputTy.getRank();
314 
315  // Setup the default constantAttr.
316 
317  Value padConstant;
318 
319  if (padOp.getPadConst()) {
320  padConstant = rewriter.createOrFold<tensor::ExtractOp>(
321  loc, padOp.getPadConst(), ValueRange({}));
322  } else {
323  TypedAttr constantAttr;
324  if (isa<FloatType>(elementTy)) {
325  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
326  } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
327  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
328  } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
329  int64_t value = padOp.getQuantizationInfo()->getInputZp();
330  constantAttr = rewriter.getIntegerAttr(elementTy, value);
331  }
332  if (constantAttr)
333  padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
334  }
335 
336  if (!padConstant) {
337  return rewriter.notifyMatchFailure(
338  padOp, "tosa.pad was unable to determine the pad constant value.");
339  }
340 
341  Value lowIndex =
342  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
343  Value highIndex =
344  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
345 
347  SmallVector<OpFoldResult, 3> highValues;
348 
349  lowValues.reserve(rank);
350  highValues.reserve(rank);
351 
352  for (int i = 0; i < rank; i++) {
353  Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
354  Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
355  loc, padding, ValueRange({inputIndex, lowIndex}));
356  Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
357  loc, padding, ValueRange({inputIndex, highIndex}));
358 
359  lowVal = rewriter.createOrFold<arith::IndexCastOp>(
360  loc, rewriter.getIndexType(), lowVal);
361  highVal = rewriter.createOrFold<arith::IndexCastOp>(
362  loc, rewriter.getIndexType(), highVal);
363 
364  lowValues.push_back(lowVal);
365  highValues.push_back(highVal);
366  }
367 
368  auto newPadOp = rewriter.create<tensor::PadOp>(
369  loc, padOp.getType(), input, lowValues, highValues, padConstant);
370 
371  rewriter.replaceOp(padOp, newPadOp.getResult());
372  return success();
373  }
374 };
375 
376 struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
378 
379  LogicalResult
380  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
381  ConversionPatternRewriter &rewriter) const override {
382  auto resultType = dyn_cast<RankedTensorType>(op.getType());
383 
384  Location loc = op.getLoc();
385  int axis = op.getAxis();
386  Value axisValue =
387  rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
388  int64_t rank = resultType.getRank();
389 
390  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
391  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
393  tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);
394 
395  // Pre-compute the offsets along the axis dimension.
396  // The axisOffsets will be of size rank + 1, where the last value
397  // will hold the total size of the tensor along the 'axis' dimension.
398  SmallVector<OpFoldResult> axisOffsets;
399  axisOffsets.push_back(rewriter.getIndexAttr(0));
400  axisOffsets.push_back(sizes[axis]);
401 
402  for (auto arg : adaptor.getOperands().drop_front()) {
403  auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
404  auto currentOffset =
405  getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
406  auto total =
407  rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
408  axisOffsets.push_back(getAsOpFoldResult(total));
409  }
410  sizes[axis] = axisOffsets.back();
411 
412  // Compute the dynamic sizes of the tensor.empty operation.
413  // This is based off of the specified result type of the tosa.concat
414  // operation, since we don't want to change the result type of the operation
415  // during the conversion.
416  SmallVector<Value> dynDims;
417  for (int64_t i = 0; i < rank; ++i) {
418  if (resultType.isDynamicDim(i)) {
419  dynDims.push_back(
420  getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
421  }
422  }
423 
424  Value result = rewriter.create<tensor::EmptyOp>(
425  loc, resultType.getShape(), resultType.getElementType(), dynDims);
426 
427  for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
428  auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
429  offsets[axis] = offset;
430  result = rewriter.createOrFold<tensor::InsertSliceOp>(
431  loc, arg, result, offsets, sizes, strides);
432  }
433  rewriter.replaceOp(op, result);
434  return success();
435  }
436 };
437 
438 } // namespace
439 
441  const TypeConverter &converter, RewritePatternSet *patterns) {
442  patterns
443  ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
444  converter, patterns->getContext());
445 }
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
RankedTensorType clone(ArrayRef< int64_t > shape, Type elementType) const
Return a clone of this type with the given new shape and element type.
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
void populateTosaToTensorConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.