MLIR  22.0.0git
PtrDialect.cpp
Go to the documentation of this file.
1 //===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Pointer dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Matchers.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::ptr;
22 
23 //===----------------------------------------------------------------------===//
24 // Pointer dialect
25 //===----------------------------------------------------------------------===//
26 
/// Registers every operation, attribute, and type of the Ptr dialect with the
/// dialect instance. The concrete class lists are expanded from the generated
/// `.inc` files selected by the `GET_*_LIST` macros.
void PtrDialect::initialize() {
  // Operations: list comes from PtrOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
      >();
  // Attributes: list comes from PtrOpsAttrs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
      >();
  // Types: list comes from PtrOpsTypes.cpp.inc.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
      >();
}
41 
42 //===----------------------------------------------------------------------===//
43 // FromPtrOp
44 //===----------------------------------------------------------------------===//
45 
46 OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
47  // Fold the pattern:
48  // %ptr = ptr.to_ptr %v : type -> ptr
49  // (%mda = ptr.get_metadata %v : type)?
50  // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
51  // To:
52  // %val -> %v
53  Value ptrLike;
54  FromPtrOp fromPtr = *this;
55  while (fromPtr != nullptr) {
56  auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
57  // Cannot fold if it's not a `to_ptr` op or the initial and final types are
58  // different.
59  if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
60  return ptrLike;
61  Value md = fromPtr.getMetadata();
62  // If the type has trivial metadata fold.
63  if (!fromPtr.getType().hasPtrMetadata()) {
64  ptrLike = toPtr.getPtr();
65  } else if (md) {
66  // Fold if the metadata can be verified to be equal.
67  if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
68  mdOp && mdOp.getPtr() == toPtr.getPtr())
69  ptrLike = toPtr.getPtr();
70  }
71  // Check for a sequence of casts.
72  fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
73  }
74  return ptrLike;
75 }
76 
77 LogicalResult FromPtrOp::verify() {
78  if (isa<PtrType>(getType()))
79  return emitError() << "the result type cannot be `!ptr.ptr`";
80  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
81  return emitError()
82  << "expected the input and output to have the same memory space";
83  }
84  return success();
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // PtrAddOp
89 //===----------------------------------------------------------------------===//
90 
91 /// Fold: ptradd ptr + 0 -> ptr
92 OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
93  Attribute attr = adaptor.getOffset();
94  if (!attr)
95  return nullptr;
96  if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
97  return getBase();
98  return nullptr;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // ToPtrOp
103 //===----------------------------------------------------------------------===//
104 
105 OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
106  // Fold the pattern:
107  // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
108  // %ptr = ptr.to_ptr %val : type -> ptr
109  // To:
110  // %ptr -> %p
111  Value ptr;
112  ToPtrOp toPtr = *this;
113  while (toPtr != nullptr) {
114  auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
115  // Cannot fold if it's not a `from_ptr` op.
116  if (!fromPtr)
117  return ptr;
118  ptr = fromPtr.getPtr();
119  // Check for chains of casts.
120  toPtr = ptr.getDefiningOp<ToPtrOp>();
121  }
122  return ptr;
123 }
124 
125 LogicalResult ToPtrOp::verify() {
126  if (isa<PtrType>(getPtr().getType()))
127  return emitError() << "the input value cannot be of type `!ptr.ptr`";
128  if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
129  return emitError()
130  << "expected the input and output to have the same memory space";
131  }
132  return success();
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // TypeOffsetOp
137 //===----------------------------------------------------------------------===//
138 
139 llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
140  if (layout)
141  return layout->getTypeSize(getElementType());
142  DataLayout dl = DataLayout::closest(*this);
143  return dl.getTypeSize(getElementType());
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Pointer API.
148 //===----------------------------------------------------------------------===//
149 
150 #include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
151 
152 #define GET_ATTRDEF_CLASSES
153 #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
154 
155 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc"
156 
157 #include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc"
158 
159 #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
160 
161 #define GET_TYPEDEF_CLASSES
162 #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
163 
164 #define GET_OP_CLASSES
165 #include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static Type getElementType(Type type)
Determine the element type of type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Include the generated interface declarations.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423