MLIR  18.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 using namespace mlir;
18 
19 /// Generate an error message string for the given op and the specified error.
20 static std::string generateErrorMessage(Operation *op, const std::string &msg) {
21  std::string buffer;
22  llvm::raw_string_ostream stream(buffer);
23  OpPrintingFlags flags;
24  stream << "ERROR: Runtime op verification failed\n";
25  op->print(stream, flags);
26  stream << "\n^ " << msg;
27  stream << "\nLocation: ";
28  op->getLoc().print(stream);
29  return stream.str();
30 }
31 
32 namespace mlir {
33 namespace memref {
34 namespace {
35 struct CastOpInterface
36  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
37  CastOp> {
38  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
39  Location loc) const {
40  auto castOp = cast<CastOp>(op);
41  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
42 
43  // Nothing to check if the result is an unranked memref.
44  auto resultType = dyn_cast<MemRefType>(castOp.getType());
45  if (!resultType)
46  return;
47 
48  if (isa<UnrankedMemRefType>(srcType)) {
49  // Check rank.
50  Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
51  Value resultRank =
52  builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
53  Value isSameRank = builder.create<arith::CmpIOp>(
54  loc, arith::CmpIPredicate::eq, srcRank, resultRank);
55  builder.create<cf::AssertOp>(loc, isSameRank,
56  generateErrorMessage(op, "rank mismatch"));
57  }
58 
59  // Get source offset and strides. We do not have an op to get offsets and
60  // strides from unranked memrefs, so cast the source to a type with fully
61  // dynamic layout, from which we can then extract the offset and strides.
62  // (Rank was already verified.)
63  int64_t dynamicOffset = ShapedType::kDynamic;
64  SmallVector<int64_t> dynamicShape(resultType.getRank(),
65  ShapedType::kDynamic);
66  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
67  dynamicOffset, dynamicShape);
68  auto dynStridesType =
69  MemRefType::get(dynamicShape, resultType.getElementType(),
70  stridedLayout, resultType.getMemorySpace());
71  Value helperCast =
72  builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
73  auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
74 
75  // Check dimension sizes.
76  for (const auto &it : llvm::enumerate(resultType.getShape())) {
77  // Static dim size -> static/dynamic dim size does not need verification.
78  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
79  if (!rankedSrcType.isDynamicDim(it.index()))
80  continue;
81 
82  // Static/dynamic dim size -> dynamic dim size does not need verification.
83  if (resultType.isDynamicDim(it.index()))
84  continue;
85 
86  Value srcDimSz =
87  builder.create<DimOp>(loc, castOp.getSource(), it.index());
88  Value resultDimSz =
89  builder.create<arith::ConstantIndexOp>(loc, it.value());
90  Value isSameSz = builder.create<arith::CmpIOp>(
91  loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
92  builder.create<cf::AssertOp>(
93  loc, isSameSz,
94  generateErrorMessage(op, "size mismatch of dim " +
95  std::to_string(it.index())));
96  }
97 
98  // Get result offset and strides.
99  int64_t resultOffset;
100  SmallVector<int64_t> resultStrides;
101  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
102  return;
103 
104  // Check offset.
105  if (resultOffset != ShapedType::kDynamic) {
106  // Static/dynamic offset -> dynamic offset does not need verification.
107  Value srcOffset = metadataOp.getResult(1);
108  Value resultOffsetVal =
109  builder.create<arith::ConstantIndexOp>(loc, resultOffset);
110  Value isSameOffset = builder.create<arith::CmpIOp>(
111  loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
112  builder.create<cf::AssertOp>(loc, isSameOffset,
113  generateErrorMessage(op, "offset mismatch"));
114  }
115 
116  // Check strides.
117  for (const auto &it : llvm::enumerate(resultStrides)) {
118  // Static/dynamic stride -> dynamic stride does not need verification.
119  if (it.value() == ShapedType::kDynamic)
120  continue;
121 
122  Value srcStride =
123  metadataOp.getResult(2 + resultType.getRank() + it.index());
124  Value resultStrideVal =
125  builder.create<arith::ConstantIndexOp>(loc, it.value());
126  Value isSameStride = builder.create<arith::CmpIOp>(
127  loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
128  builder.create<cf::AssertOp>(
129  loc, isSameStride,
130  generateErrorMessage(op, "stride mismatch of dim " +
131  std::to_string(it.index())));
132  }
133  }
134 };
135 
136 struct ExpandShapeOpInterface
137  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
138  ExpandShapeOp> {
139  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
140  Location loc) const {
141  auto expandShapeOp = cast<ExpandShapeOp>(op);
142 
143  // Verify that the expanded dim sizes are a product of the collapsed dim
144  // size.
145  for (const auto &it :
146  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
147  Value srcDimSz =
148  builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
149  int64_t groupSz = 1;
150  bool foundDynamicDim = false;
151  for (int64_t resultDim : it.value()) {
152  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
153  // Keep this assert here in case the op is extended in the future.
154  assert(!foundDynamicDim &&
155  "more than one dynamic dim found in reassoc group");
156  (void)foundDynamicDim;
157  foundDynamicDim = true;
158  continue;
159  }
160  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
161  }
162  Value staticResultDimSz =
163  builder.create<arith::ConstantIndexOp>(loc, groupSz);
164  // staticResultDimSz must divide srcDimSz evenly.
165  Value mod =
166  builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
167  Value isModZero = builder.create<arith::CmpIOp>(
168  loc, arith::CmpIPredicate::eq, mod,
169  builder.create<arith::ConstantIndexOp>(loc, 0));
170  builder.create<cf::AssertOp>(
171  loc, isModZero,
172  generateErrorMessage(op, "static result dims in reassoc group do not "
173  "divide src dim evenly"));
174  }
175  }
176 };
177 } // namespace
178 } // namespace memref
179 } // namespace mlir
180 
182  DialectRegistry &registry) {
183  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
184  CastOp::attachInterface<CastOpInterface>(*ctx);
185  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
186 
187  // Load additional dialects of which ops may get created.
188  ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
189  });
190 }
static std::string generateErrorMessage(Operation *op, const std::string &msg)
Generate an error message string for the given op and the specified error.
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
void print(raw_ostream &os) const
Print the location.
Definition: Location.h:98
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:107
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45