MLIR 22.0.0git
TensorTransformOps.cpp
Go to the documentation of this file.
1//===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
18#include "mlir/IR/Builders.h"
20
21using namespace mlir;
22using namespace tensor;
23
24//===----------------------------------------------------------------------===//
25// FindPayloadReplacementOpInterface implementations
26//===----------------------------------------------------------------------===//
27
28namespace {
29struct ExtractSliceOpReplacementInterface
30 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32 SmallVector<Value> getNextOperands(Operation *op) const {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34 if (!isCastLikeExtractSliceOp(extractSliceOp))
35 return {};
36 return {extractSliceOp.getSource()};
37 }
38};
39
40struct InsertSliceOpReplacementInterface
41 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43 SmallVector<Value> getNextOperands(Operation *op) const {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45 if (!isCastLikeInsertSliceOp(insertSliceOp))
46 return {};
47 return {insertSliceOp.getSource()};
48 }
49};
50
51struct ReshapeOpReplacementInterface
52 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54 SmallVector<Value> getNextOperands(Operation *op) const {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
57 }
58};
59
60template <typename ConcreteOp>
61struct ReassociativeReshapeOpReplacementInterface
62 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64 SmallVector<Value> getNextOperands(Operation *op) const {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
67 }
68};
69} // namespace
70
72 DialectRegistry &registry) {
73 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81 });
82}
83
84//===----------------------------------------------------------------------===//
85// Apply...PatternsOp
86//===----------------------------------------------------------------------===//
87
88void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
91}
92
93void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94 populatePatterns(RewritePatternSet &patterns) {
96}
97
98void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
100 tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101}
102
103void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
106}
107
108void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
109 populatePatterns(RewritePatternSet &patterns) {
111}
112
113void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
114 populatePatterns(RewritePatternSet &patterns) {
116}
117
118void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
121}
122
123void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
126}
127
128void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
130 ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
131 Operation *producer = fusedOperand->get().getDefiningOp();
132 return producer && producer->hasOneUse();
133 };
134
135 ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
136 return true;
137 };
138
139 // Add folding with reshape by expansion patterns.
140 if (getAggressive())
142 else
144}
145
146//===----------------------------------------------------------------------===//
// TypeConversionCastShapeDynamicDimsOp
148//===----------------------------------------------------------------------===//
149
150void transform::TypeConversionCastShapeDynamicDimsOp::
151 populateTypeMaterializations(TypeConverter &converter) {
152 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
153 converter.addSourceMaterialization([ignoreDynamicInfo](
154 OpBuilder &builder, Type resultType,
155 ValueRange inputs,
156 Location loc) -> Value {
157 if (inputs.size() != 1) {
158 return Value();
159 }
160 Value input = inputs[0];
161 if (!ignoreDynamicInfo &&
162 !tensor::preservesStaticInformation(resultType, input.getType())) {
163 return Value();
164 }
165 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
166 return Value();
167 }
168 return tensor::CastOp::create(builder, loc, resultType, input).getResult();
169 });
170 converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
171 ValueRange inputs,
172 Location loc) -> Value {
173 if (inputs.size() != 1) {
174 return Value();
175 }
176 Value input = inputs[0];
177 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
178 return Value();
179 }
180 return tensor::CastOp::create(builder, loc, resultType, input).getResult();
181 });
182}
183
184//===----------------------------------------------------------------------===//
185// MakeLoopIndependentOp
186//===----------------------------------------------------------------------===//
187
188DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
192 // Gather IVs.
194 Operation *nextOp = target;
195 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
196 nextOp = nextOp->getParentOfType<scf::ForOp>();
197 if (!nextOp) {
198 DiagnosedSilenceableFailure diag = emitSilenceableError()
199 << "could not find " << i
200 << "-th enclosing loop";
201 diag.attachNote(target->getLoc()) << "target op";
202 return diag;
203 }
204 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
205 }
206
207 // Rewrite IR.
208 FailureOr<Value> replacement = failure();
209 if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
210 replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
211 } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
212 replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
213 } else {
214 DiagnosedSilenceableFailure diag = emitSilenceableError()
215 << "unsupported target op";
216 diag.attachNote(target->getLoc()) << "target op";
217 return diag;
218 }
219 if (failed(replacement)) {
221 emitSilenceableError() << "could not make target op loop-independent";
222 diag.attachNote(target->getLoc()) << "target op";
223 return diag;
224 }
225 rewriter.replaceOp(target, *replacement);
226 results.push_back(replacement->getDefiningOp());
228}
229
230//===----------------------------------------------------------------------===//
231// Transform op registration
232//===----------------------------------------------------------------------===//
233
234namespace {
235class TensorTransformDialectExtension
237 TensorTransformDialectExtension> {
238public:
239 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)
240
241 using Base::Base;
242
243 void init() {
244 declareGeneratedDialect<affine::AffineDialect>();
245 declareGeneratedDialect<tensor::TensorDialect>();
246
247 registerTransformOps<
248#define GET_OP_LIST
249#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
250 >();
251 }
252};
253} // namespace
254
255#define GET_OP_CLASSES
256#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
257
259 DialectRegistry &registry) {
260 registry.addExtensions<TensorTransformDialectExtension>();
261}
…if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition Utils.cpp:125
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
std::function< bool(OpOperand *)> ControlFoldFn
Definition Transforms.h:99
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
Definition Utils.cpp:149
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns