MLIR 22.0.0git
PtrDialect.cpp
Go to the documentation of this file.
1//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Pointer dialect.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Matchers.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20
21using namespace mlir;
22using namespace mlir::ptr;
23
24//===----------------------------------------------------------------------===//
25// Pointer dialect
26//===----------------------------------------------------------------------===//
27
28void PtrDialect::initialize() {
29 addOperations<
30#define GET_OP_LIST
31#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
32 >();
33 addAttributes<
34#define GET_ATTRDEF_LIST
35#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
36 >();
37 addTypes<
38#define GET_TYPEDEF_LIST
39#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
40 >();
41}
42
43//===----------------------------------------------------------------------===//
44// Common helper functions.
45//===----------------------------------------------------------------------===//
46
47/// Verifies that the alignment attribute is a power of 2 if present.
48static LogicalResult
49verifyAlignment(std::optional<int64_t> alignment,
51 if (!alignment)
52 return success();
53 if (alignment.value() <= 0)
54 return emitError() << "alignment must be positive";
55 if (!llvm::isPowerOf2_64(alignment.value()))
56 return emitError() << "alignment must be a power of 2";
57 return success();
58}
59
60//===----------------------------------------------------------------------===//
61// ConstantOp
62//===----------------------------------------------------------------------===//
63
64OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
65
66//===----------------------------------------------------------------------===//
67// FromPtrOp
68//===----------------------------------------------------------------------===//
69
70OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
71 // Fold the pattern:
72 // %ptr = ptr.to_ptr %v : type -> ptr
73 // (%mda = ptr.get_metadata %v : type)?
74 // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
75 // To:
76 // %val -> %v
77 Value ptrLike;
78 FromPtrOp fromPtr = *this;
79 while (fromPtr != nullptr) {
80 auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
81 // Cannot fold if it's not a `to_ptr` op or the initial and final types are
82 // different.
83 if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
84 return ptrLike;
85 Value md = fromPtr.getMetadata();
86 // If the type has trivial metadata fold.
87 if (!fromPtr.getType().hasPtrMetadata()) {
88 ptrLike = toPtr.getPtr();
89 } else if (md) {
90 // Fold if the metadata can be verified to be equal.
91 if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
92 mdOp && mdOp.getPtr() == toPtr.getPtr())
93 ptrLike = toPtr.getPtr();
94 }
95 // Check for a sequence of casts.
96 fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
97 }
98 return ptrLike;
99}
100
101LogicalResult FromPtrOp::verify() {
102 if (isa<PtrType>(getType()))
103 return emitError() << "the result type cannot be `!ptr.ptr`";
104 if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
105 return emitError()
106 << "expected the input and output to have the same memory space";
107 }
108 return success();
109}
110
111//===----------------------------------------------------------------------===//
112// GatherOp
113//===----------------------------------------------------------------------===//
114
115void GatherOp::getEffects(
117 &effects) {
118 // Gather performs reads from multiple memory locations specified by ptrs
119 effects.emplace_back(MemoryEffects::Read::get(), &getPtrsMutable());
120}
121
122LogicalResult GatherOp::verify() {
123 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
124
125 // Verify that the pointer type's memory space allows loads.
126 MemorySpaceAttrInterface ms =
127 cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
128 DataLayout dataLayout = DataLayout::closest(*this);
129 if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
130 getAlignment(), &dataLayout, emitDiag))
131 return failure();
132
133 // Verify the alignment.
134 return verifyAlignment(getAlignment(), emitDiag);
135}
136
137void GatherOp::build(OpBuilder &builder, OperationState &state, Type resultType,
138 Value ptrs, Value mask, Value passthrough,
139 unsigned alignment) {
140 build(builder, state, resultType, ptrs, mask, passthrough,
141 alignment ? std::optional<int64_t>(alignment) : std::nullopt);
142}
143
144//===----------------------------------------------------------------------===//
145// LoadOp
146//===----------------------------------------------------------------------===//
147
148/// Verifies the attributes and the type of atomic memory access operations.
149template <typename OpTy>
150static LogicalResult
151verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) {
152 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
153 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
154 return memOp.emitOpError("unsupported ordering '")
155 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
156 if (!memOp.getAlignment())
157 return memOp.emitOpError("expected alignment for atomic access");
158 return success();
159 }
160 if (memOp.getSyncscope()) {
161 return memOp.emitOpError(
162 "expected syncscope to be null for non-atomic access");
163 }
164 return success();
165}
166
167void LoadOp::getEffects(
169 &effects) {
170 effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
171 // Volatile operations can have target-specific read-write effects on
172 // memory besides the one referred to by the pointer operand.
173 // Similarly, atomic operations that are monotonic or stricter cause
174 // synchronization that from a language point-of-view, are arbitrary
175 // read-writes into memory.
176 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
177 getOrdering() != AtomicOrdering::unordered)) {
178 effects.emplace_back(MemoryEffects::Write::get());
179 effects.emplace_back(MemoryEffects::Read::get());
180 }
181}
182
183LogicalResult LoadOp::verify() {
184 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
185 MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
186 DataLayout dataLayout = DataLayout::closest(*this);
187 if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(),
188 &dataLayout, emitDiag))
189 return failure();
190 if (failed(verifyAlignment(getAlignment(), emitDiag)))
191 return failure();
192 return verifyAtomicMemOp(*this,
193 {AtomicOrdering::release, AtomicOrdering::acq_rel});
194}
195
196void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
197 Value addr, unsigned alignment, bool isVolatile,
198 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
199 AtomicOrdering ordering, StringRef syncscope) {
200 build(builder, state, type, addr,
201 alignment ? std::optional<int64_t>(alignment) : std::nullopt,
202 isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering,
203 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
204}
205//===----------------------------------------------------------------------===//
206// MaskedLoadOp
207//===----------------------------------------------------------------------===//
208
209void MaskedLoadOp::getEffects(
211 &effects) {
212 // MaskedLoad performs reads from the memory location specified by ptr.
213 effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
214}
215
216LogicalResult MaskedLoadOp::verify() {
217 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
218 // Verify that the pointer type's memory space allows loads.
219 MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
220 DataLayout dataLayout = DataLayout::closest(*this);
221 if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
222 getAlignment(), &dataLayout, emitDiag))
223 return failure();
224
225 // Verify the alignment.
226 return verifyAlignment(getAlignment(), emitDiag);
227}
228
229void MaskedLoadOp::build(OpBuilder &builder, OperationState &state,
230 Type resultType, Value ptr, Value mask,
231 Value passthrough, unsigned alignment) {
232 build(builder, state, resultType, ptr, mask, passthrough,
233 alignment ? std::optional<int64_t>(alignment) : std::nullopt);
234}
235
236//===----------------------------------------------------------------------===//
237// MaskedStoreOp
238//===----------------------------------------------------------------------===//
239
240void MaskedStoreOp::getEffects(
242 &effects) {
243 // MaskedStore performs writes to the memory location specified by ptr
244 effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
245}
246
247LogicalResult MaskedStoreOp::verify() {
248 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
249 // Verify that the pointer type's memory space allows stores.
250 MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
251 DataLayout dataLayout = DataLayout::closest(*this);
252 if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
253 getAlignment(), &dataLayout, emitDiag))
254 return failure();
255
256 // Verify the alignment.
257 return verifyAlignment(getAlignment(), emitDiag);
258}
259
260void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
261 Value value, Value ptr, Value mask,
262 unsigned alignment) {
263 build(builder, state, value, ptr, mask,
264 alignment ? std::optional<int64_t>(alignment) : std::nullopt);
265}
266
267//===----------------------------------------------------------------------===//
268// ScatterOp
269//===----------------------------------------------------------------------===//
270
271void ScatterOp::getEffects(
273 &effects) {
274 // Scatter performs writes to multiple memory locations specified by ptrs
275 effects.emplace_back(MemoryEffects::Write::get(), &getPtrsMutable());
276}
277
278LogicalResult ScatterOp::verify() {
279 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
280
281 // Verify that the pointer type's memory space allows stores.
282 MemorySpaceAttrInterface ms =
283 cast<PtrType>(getPtrs().getType().getElementType()).getMemorySpace();
284 DataLayout dataLayout = DataLayout::closest(*this);
285 if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
286 getAlignment(), &dataLayout, emitDiag))
287 return failure();
288
289 // Verify the alignment.
290 return verifyAlignment(getAlignment(), emitDiag);
291}
292
293void ScatterOp::build(OpBuilder &builder, OperationState &state, Value value,
294 Value ptrs, Value mask, unsigned alignment) {
295 build(builder, state, value, ptrs, mask,
296 alignment ? std::optional<int64_t>(alignment) : std::nullopt);
297}
298
299//===----------------------------------------------------------------------===//
300// StoreOp
301//===----------------------------------------------------------------------===//
302
303void StoreOp::getEffects(
305 &effects) {
306 effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
307 // Volatile operations can have target-specific read-write effects on
308 // memory besides the one referred to by the pointer operand.
309 // Similarly, atomic operations that are monotonic or stricter cause
310 // synchronization that from a language point-of-view, are arbitrary
311 // read-writes into memory.
312 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
313 getOrdering() != AtomicOrdering::unordered)) {
314 effects.emplace_back(MemoryEffects::Write::get());
315 effects.emplace_back(MemoryEffects::Read::get());
316 }
317}
318
319LogicalResult StoreOp::verify() {
320 auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
321 MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace();
322 DataLayout dataLayout = DataLayout::closest(*this);
323 if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(),
324 &dataLayout, emitDiag))
325 return failure();
326 if (failed(verifyAlignment(getAlignment(), emitDiag)))
327 return failure();
328 return verifyAtomicMemOp(*this,
329 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
330}
331
332void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
333 Value addr, unsigned alignment, bool isVolatile,
334 bool isNonTemporal, bool isInvariantGroup,
335 AtomicOrdering ordering, StringRef syncscope) {
336 build(builder, state, value, addr,
337 alignment ? std::optional<int64_t>(alignment) : std::nullopt,
338 isVolatile, isNonTemporal, isInvariantGroup, ordering,
339 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
340}
341
342//===----------------------------------------------------------------------===//
343// PtrAddOp
344//===----------------------------------------------------------------------===//
345
346/// Fold: ptradd ptr + 0 -> ptr
347OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
348 Attribute attr = adaptor.getOffset();
349 if (!attr)
350 return nullptr;
351 if (llvm::APInt value; m_ConstantInt(&value).match(attr) && value.isZero())
352 return getBase();
353 return nullptr;
354}
355
356LogicalResult PtrAddOp::inferReturnTypes(
357 MLIRContext *context, std::optional<Location> location, ValueRange operands,
358 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
359 SmallVectorImpl<Type> &inferredReturnTypes) {
360 // Get the base pointer and offset types.
361 Type baseType = operands[0].getType();
362 Type offsetType = operands[1].getType();
363
364 auto offTy = dyn_cast<ShapedType>(offsetType);
365 if (!offTy) {
366 // If the offset isn't shaped, the result is always the base type.
367 inferredReturnTypes.push_back(baseType);
368 return success();
369 }
370 auto baseTy = dyn_cast<ShapedType>(baseType);
371 if (!baseTy) {
372 // Base isn't shaped, but offset is, use the ShapedType from offset with the
373 // base pointer as element type.
374 inferredReturnTypes.push_back(offTy.clone(baseType));
375 return success();
376 }
377
378 // Both are shaped, their shape must match.
379 if (offTy.getShape() != baseTy.getShape()) {
380 if (location)
381 mlir::emitError(*location) << "shapes of base and offset must match";
382 return failure();
383 }
384
385 // Make sure they are the same kind of shaped type.
386 if (baseType.getTypeID() != offsetType.getTypeID()) {
387 if (location)
388 mlir::emitError(*location) << "the shaped containers type must match";
389 return failure();
390 }
391 inferredReturnTypes.push_back(baseType);
392 return success();
393}
394
395//===----------------------------------------------------------------------===//
396// PtrDiffOp
397//===----------------------------------------------------------------------===//
398
399LogicalResult PtrDiffOp::verify() {
400 // If the operands are not shaped early exit.
401 if (!isa<ShapedType>(getLhs().getType()))
402 return success();
403
404 // Just check the container type matches, `SameOperandsAndResultShape` handles
405 // the actual shape.
406 if (getResult().getType().getTypeID() != getLhs().getType().getTypeID()) {
407 return emitError() << "expected the result to have the same container "
408 "type as the operands when operands are shaped";
409 }
410
411 return success();
412}
413
414ptr::PtrType PtrDiffOp::getPtrType() {
415 Type lhsType = getLhs().getType();
416 if (auto shapedType = dyn_cast<ShapedType>(lhsType))
417 return cast<ptr::PtrType>(shapedType.getElementType());
418 return cast<ptr::PtrType>(lhsType);
419}
420
421Type PtrDiffOp::getIntType() {
422 Type resultType = getResult().getType();
423 if (auto shapedType = dyn_cast<ShapedType>(resultType))
424 return shapedType.getElementType();
425 return resultType;
426}
427
428//===----------------------------------------------------------------------===//
429// ToPtrOp
430//===----------------------------------------------------------------------===//
431
432OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
433 // Fold the pattern:
434 // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
435 // %ptr = ptr.to_ptr %val : type -> ptr
436 // To:
437 // %ptr -> %p
438 Value ptr;
439 ToPtrOp toPtr = *this;
440 while (toPtr != nullptr) {
441 auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
442 // Cannot fold if it's not a `from_ptr` op.
443 if (!fromPtr)
444 return ptr;
445 ptr = fromPtr.getPtr();
446 // Check for chains of casts.
447 toPtr = ptr.getDefiningOp<ToPtrOp>();
448 }
449 return ptr;
450}
451
452LogicalResult ToPtrOp::verify() {
453 if (isa<PtrType>(getPtr().getType()))
454 return emitError() << "the input value cannot be of type `!ptr.ptr`";
455 if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
456 return emitError()
457 << "expected the input and output to have the same memory space";
458 }
459 return success();
460}
461
462//===----------------------------------------------------------------------===//
463// TypeOffsetOp
464//===----------------------------------------------------------------------===//
465
466llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
467 if (layout)
468 return layout->getTypeSize(getElementType());
470 return dl.getTypeSize(getElementType());
471}
472
473//===----------------------------------------------------------------------===//
474// Pointer API.
475//===----------------------------------------------------------------------===//
476
477#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
478
479#define GET_ATTRDEF_CLASSES
480#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
481
482#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc"
483
484#define GET_TYPEDEF_CLASSES
485#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
486
487#define GET_OP_CLASSES
488#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
return success()
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static InFlightDiagnostic emitDiag(Location location, DiagnosticSeverity severity, const Twine &message)
Helper function used to emit a diagnostic with an optionally empty twine message.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyAtomicMemOp(OpTy memOp, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
static LogicalResult verifyAlignment(std::optional< int64_t > alignment, function_ref< InFlightDiagnostic()> emitError)
Verifies that the alignment attribute is a power of 2 if present.
Attributes are known-constant values of operations.
Definition Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This represents an operation in an abstracted form, suitable for use with the builder APIs.