MLIR  18.0.0git
TensorTransformOps.cpp
Go to the documentation of this file.
1 //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 
19 using namespace mlir;
20 using namespace tensor;
21 
22 //===----------------------------------------------------------------------===//
23 // FindPayloadReplacementOpInterface implementations
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 struct ExtractSliceOpReplacementInterface
28  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
29  ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
30  SmallVector<Value> getNextOperands(Operation *op) const {
31  auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
32  if (!isCastLikeExtractSliceOp(extractSliceOp))
33  return {};
34  return {extractSliceOp.getSource()};
35  }
36 };
37 
38 struct InsertSliceOpReplacementInterface
39  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
40  InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
41  SmallVector<Value> getNextOperands(Operation *op) const {
42  auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
43  if (!isCastLikeInsertSliceOp(insertSliceOp))
44  return {};
45  return {insertSliceOp.getSource()};
46  }
47 };
48 
49 struct ReshapeOpReplacementInterface
50  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
51  ReshapeOpReplacementInterface, tensor::ReshapeOp> {
52  SmallVector<Value> getNextOperands(Operation *op) const {
53  auto reshapeOp = cast<tensor::ReshapeOp>(op);
54  return {reshapeOp.getSource()};
55  }
56 };
57 
58 template <typename ConcreteOp>
59 struct ReassociativeReshapeOpReplacementInterface
60  : public transform::FindPayloadReplacementOpInterface::ExternalModel<
61  ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
62  SmallVector<Value> getNextOperands(Operation *op) const {
63  auto reshapeOp = cast<ConcreteOp>(op);
64  return {reshapeOp.getSrc()};
65  }
66 };
67 } // namespace
68 
70  DialectRegistry &registry) {
71  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
72  CollapseShapeOp::attachInterface<
73  ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
74  ExpandShapeOp::attachInterface<
75  ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
76  ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
77  InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
78  ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
79  });
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // Apply...PatternsOp
84 //===----------------------------------------------------------------------===//
85 
86 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
87  RewritePatternSet &patterns) {
89 }
90 
91 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
92  populatePatterns(RewritePatternSet &patterns) {
94 }
95 
96 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
97  RewritePatternSet &patterns) {
98  tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
99 }
100 
101 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
102  RewritePatternSet &patterns) {
104 }
105 
106 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
107  RewritePatternSet &patterns) {
109 }
110 
111 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
112  populatePatterns(RewritePatternSet &patterns) {
114 }
115 
116 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
117  populatePatterns(RewritePatternSet &patterns) {
119 }
120 
121 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
122  RewritePatternSet &patterns) {
124 }
125 
126 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
127  RewritePatternSet &patterns) {
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // MakeLoopIndependentOp
133 //===----------------------------------------------------------------------===//
134 
135 DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
136  transform::TransformRewriter &rewriter, Operation *target,
138  transform::TransformState &state) {
139  // Gather IVs.
140  SmallVector<Value> ivs;
141  Operation *nextOp = target;
142  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
143  nextOp = nextOp->getParentOfType<scf::ForOp>();
144  if (!nextOp) {
145  DiagnosedSilenceableFailure diag = emitSilenceableError()
146  << "could not find " << i
147  << "-th enclosing loop";
148  diag.attachNote(target->getLoc()) << "target op";
149  return diag;
150  }
151  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
152  }
153 
154  // Rewrite IR.
155  FailureOr<Value> replacement = failure();
156  if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
157  replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
158  } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
159  replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
160  } else {
161  DiagnosedSilenceableFailure diag = emitSilenceableError()
162  << "unsupported target op";
163  diag.attachNote(target->getLoc()) << "target op";
164  return diag;
165  }
166  if (failed(replacement)) {
168  emitSilenceableError() << "could not make target op loop-independent";
169  diag.attachNote(target->getLoc()) << "target op";
170  return diag;
171  }
172  rewriter.replaceOp(target, *replacement);
173  results.push_back(replacement->getDefiningOp());
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // Transform op registration
179 //===----------------------------------------------------------------------===//
180 
181 namespace {
182 class TensorTransformDialectExtension
184  TensorTransformDialectExtension> {
185 public:
186  using Base::Base;
187 
188  void init() {
189  declareGeneratedDialect<affine::AffineDialect>();
190  declareGeneratedDialect<tensor::TensorDialect>();
191 
192  registerTransformOps<
193 #define GET_OP_LIST
194 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
195  >();
196  }
197 };
198 } // namespace
199 
200 #define GET_OP_CLASSES
201 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
202 
204  DialectRegistry &registry) {
205  registry.addExtensions<TensorTransformDialectExtension>();
206 }
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
Definition: Utils.cpp:76
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
Definition: Utils.cpp:96
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72