//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // 'tensor.collapse_shape' op that converts a dynamically shaped tensor into
  // a 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

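// For example, given an input of type 'tensor<?x?xf32>' and an empty target
// shape (i.e. a 0D 'tosa.reshape' result), the function above returns
// 'tensor<1x1xf32>': the cast input then collapses its static unit dims,
// rather than collapsing a dynamic tensor directly into 'tensor<f32>'.
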
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size.
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape.
  bool resultIsStatic = true;
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic) {
          resultIsStatic = false;
          return ShapedType::kDynamic;
        }

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

        // If there is a 0 component in 'newShape', resolve the placeholder
        // as 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type.
  return inputType.clone(resultShape);
}

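// Worked example of the placeholder inference above: for a static input
// 'tensor<2x3x4xf32>' (24 elements) and 'newShape' [4, -1, 2], the product of
// the non-placeholder sizes is 4 * (-1) * 2 = -8, negated to 8, so the
// placeholder resolves to 24 / 8 = 3 and the expanded type is
// 'tensor<4x3x2xf32>'. For a dynamic input such as 'tensor<?x12xf32>', the
// placeholder instead stays 'ShapedType::kDynamic'.
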
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

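// For example, for a cast input type 'tensor<2x3x4xf32>' (lhs) and an
// expanded type 'tensor<6x4xf32>' (rhs), the loop above accumulates
// 2 * 3 = 6 to match the first rhs dim and then matches 4 directly, so the
// collapsed (intermediate) type is 'tensor<6x4xf32>'. For lhs
// 'tensor<2x3x4xf32>' and rhs 'tensor<4x6xf32>', both sides accumulate up to
// 24, yielding the flat intermediate type 'tensor<24xf32>'.
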
// Compute the reassociation map used by 'tensor.collapse_shape' (and, with
// source and destination swapped, by 'tensor.expand_shape') to map the
// source shape onto the destination shape.
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    ReassociationExprs exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

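// Continuing the example above: collapsing 'tensor<2x3x4xf32>' into
// 'tensor<6x4xf32>' produces the reassociation map [[d0, d1], [d2]], i.e.
// source dims 0 and 1 merge into result dim 0, and source dim 2 maps to
// result dim 1.
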
// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = reshape.getResult().getType();
    auto input = reshape.getInput1();
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types.
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed.
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair.
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed.
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

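// Illustrative end-to-end lowering (schematic; exact assembly may differ
// across MLIR versions):
//
//   %0 = tosa.reshape %arg0 {new_shape = array<i64: 4, 6>}
//       : (tensor<2x3x4xf32>) -> tensor<4x6xf32>
//
// becomes a collapse through the intermediate shape [24] followed by an
// expand:
//
//   %c = tensor.collapse_shape %arg0 [[0, 1, 2]]
//       : tensor<2x3x4xf32> into tensor<24xf32>
//   %e = tensor.expand_shape %c [[0, 1]]
//       : tensor<24xf32> into tensor<4x6xf32>
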
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();
    SmallVector<int64_t> strides, sizes;
    ArrayRef<int64_t> starts = sliceOp.getStart();
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      // A size of -1 means "slice to the end of this dimension": compute it
      // dynamically as dim(input, index) - start.
      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

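// For example (schematic):
//
//   %0 = tosa.slice %arg0 {start = array<i64: 1, 0>, size = array<i64: 2, -1>}
//       : (tensor<4x?xf32>) -> tensor<2x?xf32>
//
// lowers to a 'tensor.extract_slice' with static offsets [1, 0], a static
// size 2 along dim 0, and a dynamic size along dim 1 computed as
// 'tensor.dim %arg0, 1' minus the start offset 0.
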
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Set up the default constant attribute.
    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    // The padding operand has shape [rank, 2]: column 0 holds the low
    // padding and column 1 the high padding for each dimension.
    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

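// For example (schematic):
//
//   %0 = tosa.pad %arg0, %padding
//       : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<4x9xf32>
//
// lowers to a 'tensor.pad' whose low/high operands are extracted from
// %padding and index-cast, with the pad constant 0.0 (since neither a
// pad_const operand nor quantization info is present here).
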
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets
    // vector has one entry per operand plus one; the last value holds the
    // total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based off of the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    // Insert each operand into the result tensor at its pre-computed offset
    // along the concatenation axis.
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

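// For example (schematic):
//
//   %0 = tosa.concat %a, %b {axis = 0 : i32}
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// lowers to 'tensor.empty() : tensor<5x4xf32>' followed by two
// 'tensor.insert_slice' ops, inserting %a at offset [0, 0] and %b at
// offset [2, 0], both with unit strides.
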
} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
    ConcatConverter,
    PadConverter,
    ReshapeConverter,
    SliceConverter
  >(patterns->getContext());
}
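
// Typical usage (a sketch, not part of this file): a conversion pass such as
// the TosaToTensor pass in this directory populates these patterns and
// applies them through the dialect conversion driver, along the lines of:
//
//   RewritePatternSet patterns(&getContext());
//   mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
//   ConversionTarget target(getContext());
//   target.addLegalDialect<tensor::TensorDialect, arith::ArithDialect>();
//   target.addIllegalOp<tosa::ReshapeOp, tosa::SliceOp, tosa::PadOp,
//                       tosa::ConcatOp>();
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();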