18 #include "llvm/ADT/APInt.h"
24 Value input, RingAttr ring) {
26 auto bitWidth = tensorType.getElementTypeBitWidth();
27 APInt cmod(1 + bitWidth, 1);
28 cmod = cmod << bitWidth;
30 build(builder, result, resultType, input);
35 RingAttr ring = getOutput().getType().getRing();
36 IntPolynomialAttr polyMod = ring.getPolynomialModulus();
38 unsigned polyDegree = polyMod.getPolynomial().getDegree();
39 bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
42 <<
"input type " << getInput().getType()
43 <<
" does not match output type "
44 << getOutput().getType();
46 <<
"the input type must be a tensor of shape [d] where d "
47 "is at most the degree of the polynomialModulus of "
48 "the output type's ring attribute";
53 unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
54 if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
56 <<
"input tensor element type "
57 << getInput().getType().getElementType()
58 <<
" is too large to fit in the coefficients of "
59 << getOutput().getType();
60 diag.attachNote() <<
"the input tensor's elements must be rescaled"
61 " to fit before using from_tensor";
70 IntPolynomialAttr polyMod =
71 getInput().getType().getRing().getPolynomialModulus();
73 unsigned polyDegree = polyMod.getPolynomial().getDegree();
74 bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
80 <<
"input type " << getInput().getType()
81 <<
" does not match output type "
82 << getOutput().getType();
84 <<
"the output type must be a tensor of shape [d] where d "
85 "is at most the degree of the polynomialModulus of "
86 "the input type's ring attribute";
94 Type argType = getPolynomial().getType();
95 PolynomialType polyType;
97 if (
auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
98 polyType = cast<PolynomialType>(shapedPolyType.getElementType());
100 polyType = cast<PolynomialType>(argType);
103 Type coefficientType = polyType.getRing().getCoefficientType();
105 if (coefficientType != getScalar().
getType())
106 return emitOpError() <<
"polynomial coefficient type " << coefficientType
107 <<
" does not match scalar type "
108 << getScalar().getType();
118 unsigned requiredBitWidth =
119 std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
120 APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
121 APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
122 assert(r.ule(cmodExt) &&
"root must be less than cmod");
123 uint64_t upperBound = n.getZExtValue();
126 for (
size_t k = 1; k < upperBound; k++) {
129 a = (a * r).urem(cmodExt);
137 RankedTensorType tensorType,
138 std::optional<PrimitiveRootAttr> root) {
139 Attribute encoding = tensorType.getEncoding();
142 <<
"expects a ring encoding to be provided to the tensor";
144 auto encodedRing = dyn_cast<RingAttr>(encoding);
147 <<
"the provided tensor encoding is not a ring attribute";
150 if (encodedRing != ring) {
152 <<
"encoded ring type " << encodedRing
153 <<
" is not equivalent to the polynomial ring " << ring;
156 unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
158 bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
161 <<
"tensor type " << tensorType
162 <<
" does not match output type " << ring;
163 diag.attachNote() <<
"the tensor must have shape [d] where d "
164 "is exactly the degree of the polynomialModulus of "
165 "the polynomial type's ring attribute";
169 if (root.has_value()) {
170 APInt rootValue = root.value().getValue().getValue();
171 APInt rootDegree = root.value().getDegree().getValue();
172 APInt cmod = ring.getCoefficientModulus().getValue();
175 <<
"provided root " << rootValue.getZExtValue()
176 <<
" is not a primitive root "
177 <<
"of unity mod " << cmod.getZExtValue()
178 <<
", with the specified degree " << rootDegree.getZExtValue();
187 getOutput().
getType(), getRoot());
192 getInput().
getType(), getRoot());
226 TypedIntPolynomialAttr typedIntPolyAttr;
229 typedIntPolyAttr,
"value", result.
attributes);
231 result.
addTypes(typedIntPolyAttr.getType());
235 TypedFloatPolynomialAttr typedFloatPolyAttr;
237 typedFloatPolyAttr,
"value", result.
attributes);
239 result.
addTypes(typedFloatPolyAttr.getType());
248 if (
auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
250 intPoly.getValue().print(p);
251 }
else if (
auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
253 floatPoly.getValue().print(p);
255 assert(
false &&
"unexpected attribute type");
261 LogicalResult ConstantOp::inferReturnTypes(
262 MLIRContext *context, std::optional<mlir::Location> location,
263 ConstantOp::Adaptor adaptor,
266 if (
auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
267 inferredReturnTypes.push_back(intPoly.getType());
268 }
else if (
auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
269 inferredReturnTypes.push_back(floatPoly.getType());
271 assert(
false &&
"unexpected attribute type");
282 #include "PolynomialCanonicalization.inc"
287 results.
add<SubAsAdd>(context);
292 results.
add<NTTAfterINTT>(context);
297 results.
add<INTTAfterNTT>(context);
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, const APInt &cmod)
Test if a value is a primitive nth root of unity modulo cmod.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, RankedTensorType tensorType, std::optional< PrimitiveRootAttr > root)
Verify that the types involved in an NTT or INTT operation are compatible.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)