MLIR  19.0.0git
TensorTransformOps.cpp
Go to the documentation of this file.
1 //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/Builders.h"
20 
21 using namespace mlir;
22 using namespace tensor;
23 
24 //===----------------------------------------------------------------------===//
25 // FindPayloadReplacementOpInterface implementations
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 struct ExtractSliceOpReplacementInterface
30  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31  ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32  SmallVector<Value> getNextOperands(Operation *op) const {
33  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34  if (!isCastLikeExtractSliceOp(extractSliceOp))
35  return {};
36  return {extractSliceOp.getSource()};
37  }
38 };
39 
40 struct InsertSliceOpReplacementInterface
41  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42  InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43  SmallVector<Value> getNextOperands(Operation *op) const {
44  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45  if (!isCastLikeInsertSliceOp(insertSliceOp))
46  return {};
47  return {insertSliceOp.getSource()};
48  }
49 };
50 
51 struct ReshapeOpReplacementInterface
52  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53  ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54  SmallVector<Value> getNextOperands(Operation *op) const {
55  auto reshapeOp = cast<tensor::ReshapeOp>(op);
56  return {reshapeOp.getSource()};
57  }
58 };
59 
60 template <typename ConcreteOp>
61 struct ReassociativeReshapeOpReplacementInterface
62  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63  ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64  SmallVector<Value> getNextOperands(Operation *op) const {
65  auto reshapeOp = cast<ConcreteOp>(op);
66  return {reshapeOp.getSrc()};
67  }
68 };
69 } // namespace
70 
72  DialectRegistry &registry) {
73  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74  CollapseShapeOp::attachInterface<
75  ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76  ExpandShapeOp::attachInterface<
77  ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78  ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79  InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80  ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81  });
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Apply...PatternsOp
86 //===----------------------------------------------------------------------===//
87 
88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
89  RewritePatternSet &patterns) {
91 }
92 
93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94  populatePatterns(RewritePatternSet &patterns) {
96 }
97 
98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
99  RewritePatternSet &patterns) {
100  tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101 }
102 
103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
104  RewritePatternSet &patterns) {
106 }
107 
108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
109  RewritePatternSet &patterns) {
111 }
112 
113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
114  populatePatterns(RewritePatternSet &patterns) {
116 }
117 
118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
119  populatePatterns(RewritePatternSet &patterns) {
121 }
122 
123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
124  RewritePatternSet &patterns) {
126 }
127 
128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129  RewritePatternSet &patterns) {
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // TypeConversionCastTensorShapeOp
135 //===----------------------------------------------------------------------===//
136 
137 void transform::TypeConversionCastShapeDynamicDimsOp::
138  populateTypeMaterializations(TypeConverter &converter) {
139  bool ignoreDynamicInfo = getIgnoreDynamicInfo();
140  converter.addSourceMaterialization([ignoreDynamicInfo](
141  OpBuilder &builder, Type resultType,
142  ValueRange inputs,
143  Location loc) -> std::optional<Value> {
144  if (inputs.size() != 1) {
145  return std::nullopt;
146  }
147  Value input = inputs[0];
148  if (!ignoreDynamicInfo &&
149  !tensor::preservesStaticInformation(resultType, input.getType())) {
150  return std::nullopt;
151  }
152  if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
153  return std::nullopt;
154  }
155  return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
156  });
157  converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
158  ValueRange inputs,
159  Location loc) -> std::optional<Value> {
160  if (inputs.size() != 1) {
161  return std::nullopt;
162  }
163  Value input = inputs[0];
164  if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
165  return std::nullopt;
166  }
167  return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
168  });
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // MakeLoopIndependentOp
173 //===----------------------------------------------------------------------===//
174 
175 DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
176  transform::TransformRewriter &rewriter, Operation *target,
178  transform::TransformState &state) {
179  // Gather IVs.
180  SmallVector<Value> ivs;
181  Operation *nextOp = target;
182  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
183  nextOp = nextOp->getParentOfType<scf::ForOp>();
184  if (!nextOp) {
185  DiagnosedSilenceableFailure diag = emitSilenceableError()
186  << "could not find " << i
187  << "-th enclosing loop";
188  diag.attachNote(target->getLoc()) << "target op";
189  return diag;
190  }
191  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
192  }
193 
194  // Rewrite IR.
195  FailureOr<Value> replacement = failure();
196  if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
197  replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
198  } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
199  replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
200  } else {
201  DiagnosedSilenceableFailure diag = emitSilenceableError()
202  << "unsupported target op";
203  diag.attachNote(target->getLoc()) << "target op";
204  return diag;
205  }
206  if (failed(replacement)) {
208  emitSilenceableError() << "could not make target op loop-independent";
209  diag.attachNote(target->getLoc()) << "target op";
210  return diag;
211  }
212  rewriter.replaceOp(target, *replacement);
213  results.push_back(replacement->getDefiningOp());
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // Transform op registration
219 //===----------------------------------------------------------------------===//
220 
221 namespace {
222 class TensorTransformDialectExtension
224  TensorTransformDialectExtension> {
225 public:
226  using Base::Base;
227 
228  void init() {
229  declareGeneratedDialect<affine::AffineDialect>();
230  declareGeneratedDialect<tensor::TensorDialect>();
231 
232  registerTransformOps<
233 #define GET_OP_LIST
234 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
235  >();
236  }
237 };
238 } // namespace
239 
240 #define GET_OP_CLASSES
241 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
242 
244  DialectRegistry &registry) {
245  registry.addExtensions<TensorTransformDialectExtension>();
246 }
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Type conversion class.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition: Utils.cpp:142
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
Definition: Utils.cpp:166
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:263
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72