14 #include "llvm/Support/Debug.h"
16 #define DEBUG_TYPE "xegpu"
24 for (
size_t i = 0; i < trans.size(); i++)
25 shape[i] = old[trans[i]];
29 static std::string
makeString(T array,
bool breakline =
false) {
32 llvm::raw_string_ostream os(buf);
34 for (
size_t i = 1; i < array.size(); i++) {
35 os << array[i - 1] <<
", ";
39 os << array.back() <<
"]";
46 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
55 if (
auto ty = llvm::dyn_cast<ShapedType>(type))
63 auto kind = attr.getValue();
64 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
65 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
71 auto kind = attr.getValue();
72 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
73 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
82 [[maybe_unused]]
auto ty = source.getType();
83 assert(ty.hasStaticShape() && offsets.size() == (
size_t)ty.getRank());
89 build(builder, state, tdesc, source, dynamicOffsets ,
96 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
97 Type tdesc, TypedValue<IntegerType> source,
101 assert(shape.size() && offsets.size() && strides.size() &&
102 shape.size() == strides.size() && shape.size() == offsets.size());
115 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
116 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
117 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
119 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
120 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
124 auto rank = (int64_t)getMixedOffsets().size();
125 bool invalidRank = (rank != 2);
126 bool invalidElemTy =
false;
130 auto memrefTy = dyn_cast<MemRefType>(getSourceType());
132 invalidRank |= (memrefTy.getRank() != rank);
137 invalidRank = (getType().getRank() != rank);
144 "Expecting the rank of shape, strides, offsets, "
145 "source memref type (if source is a memref) and TensorDesc "
146 "should match with each other. They currenlty are 2D.");
149 return emitOpError(
"TensorDesc should have the same element "
150 "type with the source if it is a memref.\n");
152 if (getType().getScattered())
153 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
162 auto tdescTy = getTensorDescType();
163 if (tdescTy.getScattered())
164 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
167 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
170 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
173 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
182 auto tdescTy = getTensorDescType();
183 auto valueTy = getType();
185 if (tdescTy.getRank() != 2)
186 return emitOpError(
"Expecting a 2D TensorDesc.\n");
188 if (tdescTy.getScattered())
189 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
192 return emitOpError(
"Invalid result, it should be a VectorType.\n");
195 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
198 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
201 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
203 auto array_len = tdescTy.getArrayLength();
207 if (getTranspose()) {
208 auto trans = getTranspose().value();
209 if (tdescShape.size() >= trans.size())
212 emitWarning(
"Invalid transpose attr. It is ignored.");
216 auto axis = getVnniAxis().value();
217 auto vnni_factor = valueShape.back();
218 tdescShape[axis] /= vnni_factor;
219 tdescShape.push_back(vnni_factor);
223 auto it = tdescShape.begin();
224 tdescShape.insert(it, array_len);
227 if (tdescShape != valueShape)
228 return emitOpError() <<
"Result shape doesn't match TensorDesc shape."
229 <<
"The expected shape is " <<
makeString(tdescShape)
230 <<
". But the given shape is "
239 auto dstTy = getTensorDescType();
240 auto valTy = getValueType();
242 if (dstTy.getRank() != 2)
243 return emitOpError(
"Expecting a 2D TensorDesc.\n");
245 if (dstTy.getScattered())
246 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
249 return emitOpError(
"Exepcting a VectorType result.\n");
252 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
255 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
258 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
267 auto ty = getTensorDescType();
268 if (ty.getScattered())
269 return emitOpError(
"Expects a non-scattered TensorDesc.\n");
272 if (ty.getRank() != (int64_t)getNumOffsets()) {
273 return emitOpError(
"Invalid number of offsets.");
281 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
282 TensorDescType TensorDesc, Value source,
284 uint32_t chunk_size) {
288 build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
293 auto tdescTy = getTensorDescType();
294 auto chunkSize = getChunkSize();
298 "Expecting the source is a 1D memref or pointer (uint64_t).");
300 if (!tdescTy.getScattered())
301 return emitOpError(
"Expects a scattered TensorDesc.\n");
303 SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
305 shape.push_back(chunkSize);
308 if (shape != tdescShape)
309 return emitOpError(
"Incorrect TensorDesc shape. ")
310 <<
"Expected is " <<
makeString(shape) <<
"\n";
319 auto tdescTy = getTensorDescType();
320 if (!tdescTy.getScattered())
321 return emitOpError(
"Expects a scattered TensorDesc.\n");
324 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
327 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
330 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
339 auto tdescTy = getTensorDescType();
340 auto maskTy = getMaskType();
341 auto valueTy = getValueType();
343 if (!tdescTy.getScattered())
344 return emitOpError(
"Expects a scattered TensorDesc.\n");
347 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
350 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
353 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
355 auto tdescElemTy = tdescTy.getElementType();
357 if (tdescElemTy != valueElemTy)
359 "Value should have the same element type as TensorDesc.");
365 if (tdescShape[0] != maskShape[0])
366 return emitOpError(
"dim-0 of the Mask and TensorDesc should be the same.");
368 if (getTransposeAttr()) {
369 auto trans = getTranspose().value();
370 if (tdescShape.size() < trans.size())
371 emitWarning(
"Invalid transpose attr. It is ignored.");
376 if (valueShape != tdescShape)
377 return emitOpError(
"Unexpected result shape")
378 <<
"(Expected shape: " <<
makeString(tdescShape)
379 <<
", Given shape: " <<
makeString(valueShape) <<
").\n";
388 auto tdescTy = getTensorDescType();
389 if (!tdescTy.getScattered())
390 return emitOpError(
"Expects a scattered TensorDesc.\n");
393 return emitOpError(
"invlid l1_hint: ") << getL1HintAttr();
396 return emitOpError(
"invlid l2_hint: ") << getL2HintAttr();
399 return emitOpError(
"invlid l3_hint: ") << getL3HintAttr();
401 auto maskTy = getMaskType();
404 if (tdescShape[0] != maskShape[0])
405 return emitOpError(
"dim-0 of the Mask and TensorDesc should be the same.");
413 int64_t lhsRank = getLhsType().getRank();
414 int64_t rhsRank = getRhsType().getRank();
416 if (lhsRank != rhsRank || lhsRank != 3)
418 "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
420 if (getAcc() && getAccType() != getResultType())
421 return emitOpError(
"Accumulator and Result for dpas op should have the "
422 "same type (both shape and element type).");
424 auto lhsShape = getLhsType().getShape();
425 auto rhsShape = getRhsType().getShape();
426 if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
427 return emitOpError(
"K-dimension or vnni-factor mismatch.");
435 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
436 #define GET_OP_CLASSES
437 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static std::string makeString(T array, bool breakline=false)
static int64_t getRankOf(Value val)
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
static bool isReadHintOrNone(const CachePolicyAttr &attr)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
static SmallVector< int64_t > getShapeOf(Type type)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.