TosaToTensor.cpp
//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}
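
// For instance, reshaping a 'tensor<?x?xf32>' input to a rank-0 result yields
// the cast type 'tensor<1x1xf32>' here, so the collapse emitted later operates
// on a statically shaped tensor.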

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

        // If there is a 0 component in 'newShape', resolve the placeholder as
        // 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}
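
// Worked example: for a static input 'tensor<6xf32>' and newShape = {2, -1},
// the product of the non-placeholder sizes is 2, so the -1 placeholder
// resolves to 6 / 2 = 3 and the expanded type is 'tensor<2x3xf32>'. For a
// dynamic input such as 'tensor<?xf32>', the placeholder stays dynamic and the
// expanded type is 'tensor<2x?xf32>'.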

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}
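
// Worked example: with lhsType = 'tensor<2x3x4xf32>' and rhsType =
// 'tensor<6x4xf32>', sizes are accumulated on either side until their products
// match (2 * 3 == 6, then 4 == 4), giving the intermediate collapsed type
// 'tensor<6x4xf32>'.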

SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}
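
// Worked example: collapsing 'tensor<2x3x4xf32>' into 'tensor<6x4xf32>'
// produces the reassociation map [[d0, d1], [d2]]: source dimensions 0 and 1
// are grouped into result dimension 0, and source dimension 2 maps to result
// dimension 1.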

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}
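
// Note that 'createExpand' calls 'createReassociationMapForCollapse' with the
// source and result types swapped: expanding 'tensor<6x4xf32>' into
// 'tensor<2x3x4xf32>' uses the same [[d0, d1], [d2]] grouping, now read as an
// expansion.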

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape)) {
      return failure();
    }

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};
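
// End-to-end sketch (illustrative; exact IR syntax may differ): reshaping
// 'tensor<2x3x4xf32>' to the shape {6, 4} infers inputType 'tensor<2x3x4xf32>'
// and both expandedType and collapsedType 'tensor<6x4xf32>', so the pattern
// emits roughly
//   %0 = tensor.collapse_shape %in [[0, 1], [2]]
//          : tensor<2x3x4xf32> into tensor<6x4xf32>
// while the input cast, the expand and the final cast fold away because the
// types already match.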

class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the tosa.const_shape ops feeding 'start' and 'size' if they have
    // no other uses.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};
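
// For instance, slicing a 'tensor<?x10xf32>' with start = [1, 2] and
// size = [-1, 4] keeps the -1 extent dynamic: dimension 0 gets a tensor.dim
// minus the start offset 1 as its dynamic size, while dimension 1 uses the
// static offset 2 and size 4 in the resulting tensor.extract_slice.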

class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern<tosa::PadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the pad constant value from the 'pad_const' operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(), ValueRange({}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
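
// For instance, for a rank-2 input with padding values [1, 2, 3, 4], the even
// entries become the low padding [1, 3] and the odd entries the high padding
// [2, 4] of the emitted tensor.pad, with the extracted pad_const scalar used
// as the padding value.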

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. 'axisOffsets' holds
    // one entry per concatenated operand plus one, where the last value is the
    // total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
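
// For instance, concatenating 'tensor<2x3xf32>' and 'tensor<4x3xf32>' along
// axis 0 (with result type 'tensor<6x3xf32>') creates an empty
// 'tensor<6x3xf32>' and inserts the operands with tensor.insert_slice at
// offsets [0, 0] and [2, 0] respectively.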

} // namespace

void mlir::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}
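
// Minimal usage sketch (assuming a surrounding conversion pass that owns a
// configured TypeConverter 'converter' and ConversionTarget 'target'):
//   RewritePatternSet patterns(ctx);
//   mlir::populateTosaToTensorConversionPatterns(converter, &patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();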