MLIR  19.0.0git
TensorTransformOps.cpp
Go to the documentation of this file.
1 //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/Builders.h"
20 
21 using namespace mlir;
22 using namespace tensor;
23 
24 //===----------------------------------------------------------------------===//
25 // FindPayloadReplacementOpInterface implementations
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 struct ExtractSliceOpReplacementInterface
30  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31  ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32  SmallVector<Value> getNextOperands(Operation *op) const {
33  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34  if (!isCastLikeExtractSliceOp(extractSliceOp))
35  return {};
36  return {extractSliceOp.getSource()};
37  }
38 };
39 
40 struct InsertSliceOpReplacementInterface
41  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42  InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43  SmallVector<Value> getNextOperands(Operation *op) const {
44  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45  if (!isCastLikeInsertSliceOp(insertSliceOp))
46  return {};
47  return {insertSliceOp.getSource()};
48  }
49 };
50 
51 struct ReshapeOpReplacementInterface
52  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53  ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54  SmallVector<Value> getNextOperands(Operation *op) const {
55  auto reshapeOp = cast<tensor::ReshapeOp>(op);
56  return {reshapeOp.getSource()};
57  }
58 };
59 
60 template <typename ConcreteOp>
61 struct ReassociativeReshapeOpReplacementInterface
62  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63  ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64  SmallVector<Value> getNextOperands(Operation *op) const {
65  auto reshapeOp = cast<ConcreteOp>(op);
66  return {reshapeOp.getSrc()};
67  }
68 };
69 } // namespace
70 
72  DialectRegistry &registry) {
73  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74  CollapseShapeOp::attachInterface<
75  ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76  ExpandShapeOp::attachInterface<
77  ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78  ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79  InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80  ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81  });
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Apply...PatternsOp
86 //===----------------------------------------------------------------------===//
87 
88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
89  RewritePatternSet &patterns) {
91 }
92 
93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94  populatePatterns(RewritePatternSet &patterns) {
96 }
97 
98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
99  RewritePatternSet &patterns) {
100  tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101 }
102 
103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
104  RewritePatternSet &patterns) {
106 }
107 
108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
109  RewritePatternSet &patterns) {
111 }
112 
113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
114  populatePatterns(RewritePatternSet &patterns) {
116 }
117 
118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
119  populatePatterns(RewritePatternSet &patterns) {
121 }
122 
123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
124  RewritePatternSet &patterns) {
126 }
127 
128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129  RewritePatternSet &patterns) {
130  ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
131  Operation *producer = fusedOperand->get().getDefiningOp();
132  return producer && producer->hasOneUse();
133  };
134 
135  ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
136  return true;
137  };
138 
139  // Add folding with reshape by expansion patterns.
140  if (getAggressive())
141  tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn);
142  else
143  tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // TypeConversionCastTensorShapeOp
148 //===----------------------------------------------------------------------===//
149 
150 void transform::TypeConversionCastShapeDynamicDimsOp::
151  populateTypeMaterializations(TypeConverter &converter) {
152  bool ignoreDynamicInfo = getIgnoreDynamicInfo();
153  converter.addSourceMaterialization([ignoreDynamicInfo](
154  OpBuilder &builder, Type resultType,
155  ValueRange inputs,
156  Location loc) -> std::optional<Value> {
157  if (inputs.size() != 1) {
158  return std::nullopt;
159  }
160  Value input = inputs[0];
161  if (!ignoreDynamicInfo &&
162  !tensor::preservesStaticInformation(resultType, input.getType())) {
163  return std::nullopt;
164  }
165  if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
166  return std::nullopt;
167  }
168  return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
169  });
170  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
171  ValueRange inputs,
172  Location loc) -> std::optional<Value> {
173  if (inputs.size() != 1) {
174  return std::nullopt;
175  }
176  Value input = inputs[0];
177  if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
178  return std::nullopt;
179  }
180  return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
181  });
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // MakeLoopIndependentOp
186 //===----------------------------------------------------------------------===//
187 
188 DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
189  transform::TransformRewriter &rewriter, Operation *target,
191  transform::TransformState &state) {
192  // Gather IVs.
193  SmallVector<Value> ivs;
194  Operation *nextOp = target;
195  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
196  nextOp = nextOp->getParentOfType<scf::ForOp>();
197  if (!nextOp) {
198  DiagnosedSilenceableFailure diag = emitSilenceableError()
199  << "could not find " << i
200  << "-th enclosing loop";
201  diag.attachNote(target->getLoc()) << "target op";
202  return diag;
203  }
204  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
205  }
206 
207  // Rewrite IR.
208  FailureOr<Value> replacement = failure();
209  if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
210  replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
211  } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
212  replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
213  } else {
214  DiagnosedSilenceableFailure diag = emitSilenceableError()
215  << "unsupported target op";
216  diag.attachNote(target->getLoc()) << "target op";
217  return diag;
218  }
219  if (failed(replacement)) {
221  emitSilenceableError() << "could not make target op loop-independent";
222  diag.attachNote(target->getLoc()) << "target op";
223  return diag;
224  }
225  rewriter.replaceOp(target, *replacement);
226  results.push_back(replacement->getDefiningOp());
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Transform op registration
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 class TensorTransformDialectExtension
237  TensorTransformDialectExtension> {
238 public:
239  using Base::Base;
240 
241  void init() {
242  declareGeneratedDialect<affine::AffineDialect>();
243  declareGeneratedDialect<tensor::TensorDialect>();
244 
245  registerTransformOps<
246 #define GET_OP_LIST
247 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
248  >();
249  }
250 };
251 } // namespace
252 
253 #define GET_OP_CLASSES
254 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
255 
257  DialectRegistry &registry) {
258  registry.addExtensions<TensorTransformDialectExtension>();
259 }
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Type conversion class.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illegal source type.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal, or source, type to a legal type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts the source tensor into a destination tensor with the same shape.
Definition: Utils.cpp:131
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the concatenated size, followed by a chain of tensor.insert_slice operations on the inputs.
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the source tensor or extracts the entire source tensor.
Definition: Utils.cpp:155
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source ranked tensor type.
Definition: TensorOps.cpp:267
std::function< bool(OpOperand *)> ControlFoldFn
Definition: Transforms.h:94
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants when possible.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72