MLIR  18.0.0git
TensorInferTypeOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- TensorInferTypeOpInterfaceImpl.cpp - InferType external models ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 
16 using namespace mlir;
17 using namespace mlir::tensor;
18 
19 /// Compute a map that, for a given dimension of the expanded type, gives the
20 /// dimension in the collapsed type it maps to. Essentially, it is the inverse
21 /// of the `reassociation` maps.
24  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
25  for (const auto &map : enumerate(reassociation)) {
26  unsigned startPos =
27  map.value().getResults().front().cast<AffineDimExpr>().getPosition();
28  unsigned endPos =
29  map.value().getResults().back().cast<AffineDimExpr>().getPosition();
30  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
31  expandedDimToCollapsedDim[dim] = map.index();
32  }
33  }
34  return expandedDimToCollapsedDim;
35 }
36 
37 /// For a reshape op, compute the shape at dimension `dimIndex` of the output
38 /// in terms of the shape of `src`, when the reshape op is a collapsing
39 /// operation. It is the product of the sizes of the collapsed dimensions of
40 /// `src`.
42  OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
43  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
44  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
45  // Static dimension: return Attribute.
46  return builder.getIndexAttr(dstStaticShape[dimIndex]);
47  }
48  AffineMap map = reassociationMap[dimIndex];
49  unsigned startPos =
50  map.getResults().front().cast<AffineDimExpr>().getPosition();
51  unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
52  AffineExpr expr;
53  SmallVector<OpFoldResult> dynamicDims;
54  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
55  dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
56  AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
57  expr = (expr ? expr * currExpr : currExpr);
58  }
59 
60  // Dynamic dimension: return Value.
62  builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
63  dynamicDims)
64  ->getResult(0);
65 }
66 
67 /// Given the `src` of a collapsing reshape op and its reassociation maps,
68 /// compute the shape of the result of the reshape.
70  OpBuilder &builder, Location loc, Value src,
71  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
72  return llvm::to_vector<4>(llvm::map_range(
73  llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
74  return getCollapsedOutputDimFromInputShape(
75  builder, loc, dim, src, dstStaticShape, reassociation);
76  }));
77 }
78 
79 /// For an expanding reshape op, compute the value for a dimension of the output
80 /// from the shape of the input.
82  OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
84  llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
85  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
86  // Static dimension: return Attribute.
87  return builder.getIndexAttr(dstStaticShape[dimIndex]);
88  }
89  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90  unsigned startPos = reassociation[sourceDimPos]
91  .getResults()
92  .front()
93  .cast<AffineDimExpr>()
94  .getPosition();
95  unsigned endPos = reassociation[sourceDimPos]
96  .getResults()
97  .back()
98  .cast<AffineDimExpr>()
99  .getPosition();
100  int64_t linearizedStaticDim = 1;
101  for (auto d :
102  llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
103  if (d.index() + startPos == static_cast<unsigned>(dimIndex))
104  continue;
105  assert(!ShapedType::isDynamic(d.value()) &&
106  "single dimension cannot be expanded into multiple dynamic "
107  "dimensions");
108  linearizedStaticDim *= d.value();
109  }
110  OpFoldResult sourceDim =
111  builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
112 
113  // Dynamic dimension: return Value.
115  builder, loc,
117  0, 1,
118  builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
119  sourceDim)
120  ->getResult(0);
121 }
122 
123 /// Given the `src` of an expanding reshape op, the reassociation maps and the
124 /// result type, compute the shape of the result of the reshape.
126  OpBuilder &builder, Location loc, Value src,
127  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
128  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
129  getExpandedDimToCollapsedDimMap(reassociation);
130  return llvm::to_vector<4>(llvm::map_range(
131  llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
132  return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
133  dstStaticShape, reassociation,
134  expandedDimToCollapsedDim);
135  }));
136 }
137 
140  ArrayRef<int64_t> dstStaticShape,
141  ArrayRef<AffineMap> reassocation) {
142  return dstStaticShape.size() >
143  static_cast<size_t>(
144  llvm::cast<ShapedType>(src.getType()).getRank())
146  builder, loc, src, dstStaticShape, reassocation)
148  builder, loc, src, dstStaticShape, reassocation);
149 }
150 
151 template <typename OpTy>
153  : public ReifyRankedShapedTypeOpInterface::ExternalModel<
154  ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
157  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
158  auto loc = op->getLoc();
159  auto reshapeOp = cast<OpTy>(op);
160  reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
161  b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
162  reshapeOp.getReassociationMaps()));
163  return success();
164  }
165 };
166 
167 namespace {
168 
169 struct ReifyPadOp
170  : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
171  PadOp> {
174  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
175  auto padOp = cast<PadOp>(op);
176  Location loc = padOp.getLoc();
177  auto lowPad = padOp.getMixedLowPad();
178  auto highPad = padOp.getMixedHighPad();
180  for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
181  if (!padOp.getResultType().isDynamicDim(dim)) {
182  shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
183  continue;
184  }
185 
186  // Shape along each dimension is source dim + low pad + high pad.
187  SmallVector<OpFoldResult> mapOperands;
188  mapOperands.push_back(
189  b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
190  mapOperands.push_back(lowPad[dim]);
191  mapOperands.push_back(highPad[dim]);
192  AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) +
193  b.getAffineSymbolExpr(1);
194  shapes.push_back(getValueOrCreateConstantIndexOp(
195  b, loc,
197  b, loc, AffineMap::get(1, 2, expr), mapOperands)));
198  }
199  reifiedReturnShapes.emplace_back(std::move(shapes));
200  return success();
201  }
202 };
203 
204 } // namespace
205 
207  DialectRegistry &registry) {
208  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
209  ExpandShapeOp::attachInterface<
211  CollapseShapeOp::attachInterface<
213  PadOp::attachInterface<ReifyPadOp>(*ctx);
214  });
215 }
static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociationMap)
For a reshape op, compute the shape at dimension dimIndex of the output in terms of the shape of src,...
static llvm::DenseMap< int64_t, int64_t > getExpandedDimToCollapsedDimMap(ArrayRef< AffineMap > reassociation)
Compute a map that for a given dimension of the expanded type gives the dimension in the collapsed ty...
static OpFoldResult getExpandedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation, llvm::DenseMap< int64_t, int64_t > &expandedDimToCollapsedDim)
For an expanding reshape op, compute the value for a dimension of the output from the shape of the in...
static SmallVector< OpFoldResult, 4 > getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassocation)
static SmallVector< OpFoldResult, 4 > getExpandedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of an expanding reshape op, the reassociation maps and the result type,...
static SmallVector< OpFoldResult, 4 > getCollapsedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of a collapsing reshape op and its reassociation maps, compute the shape of the result ...
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:786
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:357
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1229
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1276
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerInferTypeOpInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Infer Type interfaces for tensor ops.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) const
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26