MLIR  19.0.0git
CastInterfaces.cpp
Go to the documentation of this file.
1 //===- CastInterfaces.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/BuiltinDialect.h"
12 #include "mlir/IR/BuiltinOps.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // Helper functions for CastOpInterface
18 //===----------------------------------------------------------------------===//
19 
20 /// Attempt to fold the given cast operation.
23  SmallVectorImpl<OpFoldResult> &foldResults) {
24  OperandRange operands = op->getOperands();
25  if (operands.empty())
26  return failure();
27  ResultRange results = op->getResults();
28 
29  // Check for the case where the input and output types match 1-1.
30  if (operands.getTypes() == results.getTypes()) {
31  foldResults.append(operands.begin(), operands.end());
32  return success();
33  }
34 
35  return failure();
36 }
37 
38 /// Attempt to verify the given cast operation.
40  auto resultTypes = op->getResultTypes();
41  if (resultTypes.empty())
42  return op->emitOpError()
43  << "expected at least one result for cast operation";
44 
45  auto operandTypes = op->getOperandTypes();
46  if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) {
47  InFlightDiagnostic diag = op->emitOpError("operand type");
48  if (operandTypes.empty())
49  diag << "s []";
50  else if (llvm::size(operandTypes) == 1)
51  diag << " " << *operandTypes.begin();
52  else
53  diag << "s " << operandTypes;
54  return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
55  << resultTypes << " are cast incompatible";
56  }
57 
58  return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // External model for BuiltinDialect ops
63 //===----------------------------------------------------------------------===//
64 
65 namespace mlir {
66 namespace {
67 // This interface cannot be implemented directly on the op because the IR build
68 // unit cannot depend on the Interfaces build unit.
69 struct UnrealizedConversionCastOpInterface
70  : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface,
71  UnrealizedConversionCastOp> {
72  static bool areCastCompatible(TypeRange inputs, TypeRange outputs) {
73  // `UnrealizedConversionCastOp` is agnostic of the input/output types.
74  return true;
75  }
76 };
77 } // namespace
78 } // namespace mlir
79 
81  DialectRegistry &registry) {
82  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
83  UnrealizedConversionCastOp::attachInterface<
84  UnrealizedConversionCastOpInterface>(*ctx);
85  });
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Table-generated class definitions
90 //===----------------------------------------------------------------------===//
91 
92 #include "mlir/Interfaces/CastInterfaces.cpp.inc"
static std::string diag(const llvm::Value &value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
type_range getTypes() const
Definition: ValueRange.cpp:35
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
void registerCastOpInterfaceExternalModels(DialectRegistry &registry)
LogicalResult foldCastInterfaceOp(Operation *op, ArrayRef< Attribute > attrOperands, SmallVectorImpl< OpFoldResult > &foldResults)
Attempt to fold the given cast operation.
LogicalResult verifyCastInterfaceOp(Operation *op)
Attempt to verify the given cast operation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26