MLIR  21.0.0git
PtrDialect.cpp
Go to the documentation of this file.
1 //===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pointer dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 using namespace mlir::ptr;
24 
25 //===----------------------------------------------------------------------===//
26 // Pointer dialect
27 //===----------------------------------------------------------------------===//
28 
/// Registers the Ptr dialect's operations, attributes, and types.
///
/// Each `#define GET_*_LIST` selects the section of the corresponding
/// generated `.cpp.inc` file that expands to the list of class names used as
/// the template arguments of `addOperations`, `addAttributes`, and `addTypes`.
void PtrDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
      >();
}
43 
44 //===----------------------------------------------------------------------===//
45 // FromPtrOp
46 //===----------------------------------------------------------------------===//
47 
48 OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
49  // Fold the pattern:
50  // %ptr = ptr.to_ptr %v : type -> ptr
51  // (%mda = ptr.get_metadata %v : type)?
52  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
53  // To:
54  // %val -> %v
55  Value ptrLike;
56  FromPtrOp fromPtr = *this;
57  while (fromPtr != nullptr) {
58  auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
59  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
60  // different.
61  if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
62  return ptrLike;
63  Value md = fromPtr.getMetadata();
64  // If the type has trivial metadata fold.
65  if (!fromPtr.getType().hasPtrMetadata()) {
66  ptrLike = toPtr.getPtr();
67  } else if (md) {
68  // Fold if the metadata can be verified to be equal.
69  if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
70  mdOp && mdOp.getPtr() == toPtr.getPtr())
71  ptrLike = toPtr.getPtr();
72  }
73  // Check for a sequence of casts.
74  fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
75  : nullptr);
76  }
77  return ptrLike;
78 }
79 
80 LogicalResult FromPtrOp::verify() {
81  if (isa<PtrType>(getType()))
82  return emitError() << "the result type cannot be `!ptr.ptr`";
83  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
84  return emitError()
85  << "expected the input and output to have the same memory space";
86  }
87  return success();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // PtrAddOp
92 //===----------------------------------------------------------------------===//
93 
94 /// Fold: ptradd ptr + 0 -> ptr
95 OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
96  Attribute attr = adaptor.getOffset();
97  if (!attr)
98  return nullptr;
99  if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
100  return getBase();
101  return nullptr;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // ToPtrOp
106 //===----------------------------------------------------------------------===//
107 
108 OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
109  // Fold the pattern:
110  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
111  // %ptr = ptr.to_ptr %val : type -> ptr
112  // To:
113  // %ptr -> %p
114  Value ptr;
115  ToPtrOp toPtr = *this;
116  while (toPtr != nullptr) {
117  auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
118  // Cannot fold if it's not a `from_ptr` op.
119  if (!fromPtr)
120  return ptr;
121  ptr = fromPtr.getPtr();
122  // Check for chains of casts.
123  toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
124  }
125  return ptr;
126 }
127 
128 LogicalResult ToPtrOp::verify() {
129  if (isa<PtrType>(getPtr().getType()))
130  return emitError() << "the input value cannot be of type `!ptr.ptr`";
131  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
132  return emitError()
133  << "expected the input and output to have the same memory space";
134  }
135  return success();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // TypeOffsetOp
140 //===----------------------------------------------------------------------===//
141 
142 llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
143  if (layout)
144  return layout->getTypeSize(getElementType());
145  DataLayout dl = DataLayout::closest(*this);
146  return dl.getTypeSize(getElementType());
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // Pointer API.
151 //===----------------------------------------------------------------------===//
152 
153 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
154 
155 #define GET_ATTRDEF_CLASSES
156 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
157 
158 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
159 
160 #include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
161 
162 #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
163 
164 #define GET_TYPEDEF_CLASSES
165 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
166 
167 #define GET_OP_CLASSES
168 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423