MLIR  20.0.0git
SparseTensorTransformOps.cpp
Go to the documentation of this file.
1 //===- SparseTensorTransformOps.cpp - sparse tensor transform ops impl ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 
13 using namespace mlir;
14 using namespace mlir::sparse_tensor;
15 
16 //===----------------------------------------------------------------------===//
17 // Transform op implementation
18 //===----------------------------------------------------------------------===//
19 
20 DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
23  bool hasSparseInOut = hasAnySparseOperandOrResult(current);
24  if (!hasSparseInOut) {
25  return emitSilenceableFailure(current->getLoc(),
26  "operation has no sparse input or output");
27  }
28  results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Transform op registration
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 class SparseTensorTransformDialectExtension
39  SparseTensorTransformDialectExtension> {
40 public:
42  SparseTensorTransformDialectExtension)
43 
44  SparseTensorTransformDialectExtension() {
45  declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
46  registerTransformOps<
47 #define GET_OP_LIST
48 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
49  >();
50  }
51 };
52 } // namespace
53 
54 #define GET_OP_CLASSES
55 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
56 
58  DialectRegistry &registry) {
59  registry.addExtensions<SparseTensorTransformDialectExtension>();
60 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff the MLIR operation has any sparse operand or result.
Definition: SparseTensor.h:196
void registerTransformDialectExtension(DialectRegistry &registry)
Registers the sparse tensor transform dialect extension with the given registry.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.