MLIR  20.0.0git
TensorInferTypeOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- InferTypeOpImpl.cpp - InferType Interface external models *- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 
16 using namespace mlir;
17 using namespace mlir::tensor;
18 
19 /// Compute a map that for a given dimension of the expanded type gives the
20 /// dimension in the collapsed type it maps to. Essentially it's the inverse of
21 /// the `reassociation` maps.
24  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
25  for (const auto &map : enumerate(reassociation)) {
26  unsigned startPos =
27  cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
28  unsigned endPos =
29  cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
30  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
31  expandedDimToCollapsedDim[dim] = map.index();
32  }
33  }
34  return expandedDimToCollapsedDim;
35 }
36 
37 /// For a reshape op that is a collapsing operation, compute the size of
38 /// dimension `dimIndex` of the output in terms of the shape of `src`. It is
39 /// the product of the sizes of the `src` dimensions that are collapsed into
40 /// that output dimension.
42  OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
43  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
44  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
45  // Static dimension: return Attribute.
46  return builder.getIndexAttr(dstStaticShape[dimIndex]);
47  }
48  AffineMap map = reassociationMap[dimIndex];
49  unsigned startPos =
50  cast<AffineDimExpr>(map.getResults().front()).getPosition();
51  unsigned endPos = cast<AffineDimExpr>(map.getResults().back()).getPosition();
52  AffineExpr expr;
53  SmallVector<OpFoldResult> dynamicDims;
54  for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
55  dynamicDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
56  AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
57  expr = (expr ? expr * currExpr : currExpr);
58  }
59 
60  // Dynamic dimension: return Value.
62  builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
63  dynamicDims)
64  ->getResult(0);
65 }
66 
67 /// Given the `src` of a collapsing reshape op and its reassociation maps,
68 /// compute the shape of the result of the reshape.
70  OpBuilder &builder, Location loc, Value src,
71  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
72  return llvm::to_vector<4>(llvm::map_range(
73  llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
74  return getCollapsedOutputDimFromInputShape(
75  builder, loc, dim, src, dstStaticShape, reassociation);
76  }));
77 }
78 
79 /// For an expanding reshape op, compute the value for a dimension of the output
80 /// from the shape of the input.
82  OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
83  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
84  llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
85  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
86  // Static dimension: return Attribute.
87  return builder.getIndexAttr(dstStaticShape[dimIndex]);
88  }
89  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
90  unsigned startPos =
91  cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
92  .getPosition();
93  unsigned endPos =
94  cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
95  .getPosition();
96  int64_t linearizedStaticDim = 1;
97  for (auto d :
98  llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
99  if (d.index() + startPos == static_cast<unsigned>(dimIndex))
100  continue;
101  assert(!ShapedType::isDynamic(d.value()) &&
102  "single dimension cannot be expanded into multiple dynamic "
103  "dimensions");
104  linearizedStaticDim *= d.value();
105  }
106  OpFoldResult sourceDim =
107  builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
108 
109  // Dynamic dimension: return Value.
111  builder, loc,
113  0, 1,
114  builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
115  sourceDim)
116  ->getResult(0);
117 }
118 
119 /// Given the `src` of an expanding reshape op, the reassociation maps and the
120 /// result type, compute the shape of the result of the reshape.
122  OpBuilder &builder, Location loc, Value src,
123  ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
124  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
125  getExpandedDimToCollapsedDimMap(reassociation);
126  return llvm::to_vector<4>(llvm::map_range(
127  llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
128  return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
129  dstStaticShape, reassociation,
130  expandedDimToCollapsedDim);
131  }));
132 }
133 
136  ArrayRef<int64_t> dstStaticShape,
137  ArrayRef<AffineMap> reassocation) {
138  return dstStaticShape.size() >
139  static_cast<size_t>(
140  llvm::cast<ShapedType>(src.getType()).getRank())
142  builder, loc, src, dstStaticShape, reassocation)
144  builder, loc, src, dstStaticShape, reassocation);
145 }
146 
147 template <typename OpTy>
149  : public ReifyRankedShapedTypeOpInterface::ExternalModel<
150  ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
151  LogicalResult
153  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
154  auto loc = op->getLoc();
155  auto reshapeOp = cast<OpTy>(op);
156  reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
157  b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
158  reshapeOp.getReassociationMaps()));
159  return success();
160  }
161 };
162 
163 namespace {
164 
165 struct ReifyPadOp
166  : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
167  PadOp> {
168  LogicalResult
170  ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
171  auto padOp = cast<PadOp>(op);
172  Location loc = padOp.getLoc();
173  auto lowPad = padOp.getMixedLowPad();
174  auto highPad = padOp.getMixedHighPad();
176  for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
177  if (!padOp.getResultType().isDynamicDim(dim)) {
178  shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
179  continue;
180  }
181 
182  // Shape along each dimension is source dim + low pad + high pad.
183  SmallVector<OpFoldResult> mapOperands;
184  mapOperands.push_back(
185  b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim));
186  mapOperands.push_back(lowPad[dim]);
187  mapOperands.push_back(highPad[dim]);
188  AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) +
189  b.getAffineSymbolExpr(1);
190  shapes.push_back(getValueOrCreateConstantIndexOp(
191  b, loc,
193  b, loc, AffineMap::get(1, 2, expr), mapOperands)));
194  }
195  reifiedReturnShapes.emplace_back(std::move(shapes));
196  return success();
197  }
198 };
199 
200 } // namespace
201 
203  DialectRegistry &registry) {
204  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
205  ExpandShapeOp::attachInterface<
207  CollapseShapeOp::attachInterface<
209  PadOp::attachInterface<ReifyPadOp>(*ctx);
210  });
211 }
static OpFoldResult getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociationMap)
For reshape op compute the shape at dimension dimIndex of the output in terms of shape of the src,...
static llvm::DenseMap< int64_t, int64_t > getExpandedDimToCollapsedDimMap(ArrayRef< AffineMap > reassociation)
Compute a map that for a given dimension of the expanded type gives the dimension in the collapsed ty...
static OpFoldResult getExpandedOutputDimFromInputShape(OpBuilder &builder, Location loc, int64_t dimIndex, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation, llvm::DenseMap< int64_t, int64_t > &expandedDimToCollapsedDim)
For an expanding reshape op, compute the value for a dimension of the output from the shape of the in...
static SmallVector< OpFoldResult, 4 > getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassocation)
static SmallVector< OpFoldResult, 4 > getExpandedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of an expanding reshape op, the reassociation maps and the result type,...
static SmallVector< OpFoldResult, 4 > getCollapsedOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, ArrayRef< int64_t > dstStaticShape, ArrayRef< AffineMap > reassociation)
Given the src of a collapsing reshape op and its reassociation maps, compute the shape of the result ...
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:907
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:387
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:383
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:212
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerInferTypeOpInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Infer Type interfaces for tensor ops.
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) const